v4.1 release update v2. (#2481)
@ -90,18 +90,17 @@ If you already know the TV layout you want to use for your tiled copy, CuTe DSL
# Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN))
gA = cute.zipped_divide(mA, tiler_mn)

where `tiler_mn` is the tile size per thread block and `tv_layout` is the TV layout which maps
thread index and inter-thread index of data array per thread to logical coordinates of elements in
input and output tensors.

Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy` utility.
Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy_tv` utility, which
infers the tiler and tv layout for the tiled copy automatically, where `tiler` is the tile size per thread
block and `tv_layout` is the TV layout which maps thread index and inter-thread index of data array per
thread to logical coordinates of elements in input and output tensors.

.. code-block:: python

    blkA = gA[((None, None), bidx)]  # (TileM,TileN)

    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)

    # get slice of tiled_copy_A for current thread
    thr_copy_A = tiled_copy_A.get_slice(tidx)
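For readers new to the TV-layout terminology, here is a minimal, self-contained sketch of the new `cute.make_tiled_copy_tv` path. The thread/value layout shapes below are illustrative assumptions for a 32-thread tile, not values taken from this example.

.. code-block:: python

    # Assumed layouts: 32 threads arranged 4 x 8, each owning a 2 x 2 value fragment.
    thr_layout = cute.make_layout((4, 8))
    val_layout = cute.make_layout((2, 2))

    copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)

    # Each thread takes its slice and partitions the per-block tile it will move.
    thr_copy = tiled_copy.get_slice(tidx)
    frag_src = thr_copy.partition_S(blkA)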
@ -140,8 +139,8 @@ def elementwise_add_kernel(
    gC: cute.Tensor,
    cC: cute.Tensor,  # coordinate tensor
    shape: cute.Shape,
    tv_layout: cute.Layout,
    tiler_mn: cute.Shape,
    thr_layout: cute.Layout,
    val_layout: cute.Layout,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
@ -165,9 +164,9 @@ def elementwise_add_kernel(
    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)

    tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
    tiled_copy_B = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
    tiled_copy_C = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)
    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)

    thr_copy_A = tiled_copy_A.get_slice(tidx)
    thr_copy_B = tiled_copy_B.get_slice(tidx)
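Once each thread has its slice of the tiled copies, the example proceeds by partitioning the block tiles, staging them through register fragments, and writing the sum back. A condensed sketch of that pattern follows; the block-tile names (`blkA`, `blkB`, `blkC`) and `thr_copy_C` are assumptions for illustration rather than lines quoted from this diff.

.. code-block:: python

    thrA = thr_copy_A.partition_S(blkA)     # per-thread view of the A tile
    thrB = thr_copy_B.partition_S(blkB)
    thrC = thr_copy_C.partition_D(blkC)

    frgA = cute.make_fragment_like(thrA)    # register fragments
    frgB = cute.make_fragment_like(thrB)
    frgC = cute.make_fragment_like(thrC)

    cute.copy(copy_atom_load, thrA, frgA)   # gmem -> registers
    cute.copy(copy_atom_load, thrB, frgB)
    frgC.store(frgA.load() + frgB.load())   # elementwise add in registers
    cute.copy(copy_atom_store, thrC, frgC)  # registers -> gmem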
@ -254,7 +253,7 @@ def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128):
    cC = cute.zipped_divide(idC, tiler=tiler_mn)
    print(f"[DSL INFO] coord tensor = {cC.type}")

    elementwise_add_kernel(gA, gB, gC, cC, mC.shape, tv_layout, tiler_mn).launch(
    elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
    )
@ -362,7 +361,7 @@ def run_elementwise_add(
        workspace_generator=generate_tensors,
        workspace_count=10,
        warmup_iterations=warmup_iterations,
        profiling_iterations=iterations,
        iterations=iterations,
    )

    # Print execution results

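The hunk above reflects a change in the benchmarking helper: the keyword `profiling_iterations` is replaced by `iterations`. A minimal call using the updated keywords might look like the sketch below; `compiled_kernel` and the iteration counts are placeholders, not values from this example.

.. code-block:: python

    avg_time_us = testing.benchmark(
        compiled_kernel,                       # a callable produced by cute.compile
        workspace_generator=generate_tensors,  # returns testing.JitArguments(...)
        workspace_count=10,
        warmup_iterations=5,
        iterations=100,
    )
    print(f"Average kernel time: {avg_time_us / 1e3:.3f} ms")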
@ -353,7 +353,7 @@ def run_elementwise_apply_and_verify(
            current_stream,
        ),
        warmup_iterations=warmup_iterations,
        profiling_iterations=iterations,
        iterations=iterations,
        use_cuda_graphs=True,
        stream=current_stream,
    )

@ -32,13 +32,13 @@ from typing import Type, Union, Callable

import torch
import cuda.bindings.driver as cuda

import cutlass.cute.testing as testing
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, warp
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
import cutlass.utils.ampere_helpers as sm80_utils
import cutlass.utils as utils

"""
A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL.
@ -163,7 +163,7 @@ class FlashAttentionForwardAmpere:
        # Check if block size setting is out of shared memory capacity
        # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size
        smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2
        smem_capacity = sm80_utils.SMEM_CAPACITY["sm80"]
        smem_capacity = utils.get_smem_capacity_in_bytes("sm_80")
        if smem_usage > smem_capacity:
            return False

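The shared-memory check above is easy to sanity-check by hand. A small worked example, using illustrative block sizes rather than this kernel's actual defaults:

.. code-block:: python

    m_block_size, n_block_size, head_dim = 128, 64, 64
    dtype_bytes = 2  # fp16/bf16 element size, matching the "* 2" in the formula above
    smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * dtype_bytes
    # = (8192 + 8192) * 2 = 32768 bytes (32 KiB)

    smem_capacity = utils.get_smem_capacity_in_bytes("sm_80")
    assert smem_usage <= smem_capacity  # comfortably within the SM80 budget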
@ -469,21 +469,9 @@ class FlashAttentionForwardAmpere:
|
||||
warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
|
||||
self._dtype,
|
||||
)
|
||||
smem_tiled_copy_Q = cute.make_tiled_copy(
|
||||
smem_copy_atom_Q,
|
||||
layout_tv=tiled_mma.tv_layout_A_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
smem_tiled_copy_K = cute.make_tiled_copy(
|
||||
smem_copy_atom_K,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
smem_tiled_copy_V = cute.make_tiled_copy(
|
||||
smem_copy_atom_V,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma)
|
||||
smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma)
|
||||
smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma)
|
||||
|
||||
smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx)
|
||||
smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx)
|
||||
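The `cute.make_tiled_copy_A/B` helpers above build shared-memory-to-register copies that already match the tiled MMA's operand partitioning. A brief sketch of how such a copy is typically consumed follows; the buffer and fragment names (`sQ`, `tCrQ`) are assumptions for illustration.

.. code-block:: python

    smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx)
    tCsQ = smem_thr_copy_Q.partition_S(sQ)        # smem source, MMA-A partitioned
    tCrQ_view = smem_thr_copy_Q.retile(tCrQ)      # register destination, retiled to match
    cute.copy(smem_copy_atom_Q, tCsQ, tCrQ_view)  # ldmatrix-backed smem -> register load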
@ -702,11 +690,7 @@ class FlashAttentionForwardAmpere:
|
||||
cute.nvgpu.CopyUniversalOp(), self._dtype
|
||||
)
|
||||
# tiled copy atom for O
|
||||
smem_tiled_copy_O = cute.make_tiled_copy(
|
||||
smem_copy_atom_O,
|
||||
layout_tv=tiled_mma.tv_layout_C_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)),
|
||||
)
|
||||
smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma)
|
||||
smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx)
|
||||
taccOrO = smem_thr_copy_O.retile(rO)
|
||||
taccOsO = smem_thr_copy_O.partition_D(sO)
|
||||
@ -1178,7 +1162,7 @@ class FlashAttentionForwardAmpere:
|
||||
return cute.arch.exp2(x)
|
||||
|
||||
|
||||
def run_flash_attention_fwd(
|
||||
def run(
|
||||
dtype: Type[cutlass.Numeric],
|
||||
batch_size: int,
|
||||
seqlen_q: int,
|
||||
@ -1193,6 +1177,8 @@ def run_flash_attention_fwd(
|
||||
warmup_iterations: int = 0,
|
||||
iterations: int = 1,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Skip unsupported testcase
|
||||
if not FlashAttentionForwardAmpere.can_implement(
|
||||
@ -1207,6 +1193,23 @@ def run_flash_attention_fwd(
|
||||
f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}"
|
||||
)
|
||||
|
||||
print(f"Running Ampere SM80 FlashAttentionForward test with:")
|
||||
print(f" dtype: {dtype}")
|
||||
print(f" batch_size: {batch_size}")
|
||||
print(f" seqlen_q: {seqlen_q}")
|
||||
print(f" seqlen_k: {seqlen_k}")
|
||||
print(f" num_head: {num_head}")
|
||||
print(f" head_dim: {head_dim}")
|
||||
print(f" softmax_scale: {softmax_scale}")
|
||||
print(f" m_block_size: {m_block_size}")
|
||||
print(f" n_block_size: {n_block_size}")
|
||||
print(f" num_threads: {num_threads}")
|
||||
print(f" is_causal: {is_causal}")
|
||||
print(f" warmup_iterations: {warmup_iterations}")
|
||||
print(f" iterations: {iterations}")
|
||||
print(f" skip_ref_check: {skip_ref_check}")
|
||||
print(f" use_cold_l2: {use_cold_l2}")
|
||||
|
||||
# Create tensor Q/K/V/O
|
||||
def create_tensor(
|
||||
batch_size: int,
|
||||
@ -1217,22 +1220,28 @@ def run_flash_attention_fwd(
|
||||
) -> cute.Tensor:
|
||||
# (batch_size, seqlen, num_head, head_dim)
|
||||
shape = (batch_size, seqlen, num_head, head_dim)
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32).random_(-2, 2).to(dtype=dtype).cuda()
|
||||
torch_tensor = (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(dtype=cutlass_torch.dtype(dtype))
|
||||
.cuda()
|
||||
)
|
||||
# assume input is 16B aligned.
|
||||
cute_tensor = (
|
||||
from_dlpack(torch_tensor, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3,
|
||||
stride_order=torch_tensor.dim_order(),
|
||||
divisibility=(128 // dtype.width),
|
||||
)
|
||||
)
|
||||
return cute_tensor, torch_tensor
|
||||
|
||||
q = create_tensor(
|
||||
batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype)
|
||||
)
|
||||
k = create_tensor(
|
||||
batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype)
|
||||
)
|
||||
v = create_tensor(
|
||||
batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype)
|
||||
)
|
||||
o = create_tensor(
|
||||
batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype)
|
||||
)
|
||||
q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
|
||||
k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
|
||||
v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
|
||||
o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
|
||||
|
||||
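The `create_tensor` helper above pairs a PyTorch allocation with a CuTe tensor view built via `from_dlpack`, with dynamic layout metadata so one compiled kernel can serve a family of shapes. A standalone sketch of that import pattern, using an assumed fp16 tensor shape:

.. code-block:: python

    t = torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda")
    mT = (
        from_dlpack(t, assumed_align=16)       # promise 16-byte alignment
        .mark_layout_dynamic(leading_dim=3)    # mode 3 is the unit-stride mode
        .mark_compact_shape_dynamic(
            mode=3,
            stride_order=t.dim_order(),
            divisibility=128 // cutlass.Float16.width,  # keep 128-bit copies legal
        )
    )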
fa2_fwd = FlashAttentionForwardAmpere(
|
||||
head_dim,
|
||||
@ -1241,78 +1250,63 @@ def run_flash_attention_fwd(
|
||||
num_threads,
|
||||
is_causal,
|
||||
)
|
||||
# assume input is 16B align.
|
||||
q_tensor = (
|
||||
from_dlpack(q, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3, stride_order=q.dim_order(), divisibility=(128 // dtype.width)
|
||||
)
|
||||
)
|
||||
k_tensor = (
|
||||
from_dlpack(k, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3, stride_order=k.dim_order(), divisibility=(128 // dtype.width)
|
||||
)
|
||||
)
|
||||
v_tensor = (
|
||||
from_dlpack(v, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3, stride_order=v.dim_order(), divisibility=(128 // dtype.width)
|
||||
)
|
||||
)
|
||||
o_tensor = (
|
||||
from_dlpack(o, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3, stride_order=o.dim_order(), divisibility=(128 // dtype.width)
|
||||
)
|
||||
)
|
||||
|
||||
# Get current CUDA stream from PyTorch
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
# Get the raw stream pointer as a CUstream
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
# compile the fa2 forward pass
|
||||
compiled_fa2_fwd = cute.compile(
|
||||
fa2_fwd, q_tensor, k_tensor, v_tensor, o_tensor, softmax_scale, current_stream
|
||||
compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream)
|
||||
|
||||
if not skip_ref_check:
|
||||
compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream)
|
||||
torch.cuda.synchronize()
|
||||
q_ref = q_torch.permute(0, 2, 1, 3)
|
||||
k_ref = k_torch.permute(0, 2, 1, 3)
|
||||
v_ref = v_torch.permute(0, 2, 1, 3)
|
||||
torch.backends.cuda.enable_flash_sdp(enabled=True)
|
||||
ref_o = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
|
||||
).permute(0, 2, 1, 3)
|
||||
torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
|
||||
k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
|
||||
v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
|
||||
o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
|
||||
return testing.JitArguments(
|
||||
q_workspace,
|
||||
k_workspace,
|
||||
v_workspace,
|
||||
o_workspace,
|
||||
softmax_scale,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
q_torch.numel() * q_torch.element_size()
|
||||
+ k_torch.numel() * k_torch.element_size()
|
||||
+ v_torch.numel() * v_torch.element_size()
|
||||
+ o_torch.numel() * o_torch.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
compiled_fa2_fwd,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
# warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_fa2_fwd(
|
||||
q_tensor,
|
||||
k_tensor,
|
||||
v_tensor,
|
||||
o_tensor,
|
||||
softmax_scale,
|
||||
current_stream,
|
||||
)
|
||||
# run the compiled fa2 forward pass
|
||||
for _ in range(iterations):
|
||||
compiled_fa2_fwd(
|
||||
q_tensor,
|
||||
k_tensor,
|
||||
v_tensor,
|
||||
o_tensor,
|
||||
softmax_scale,
|
||||
current_stream,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if skip_ref_check:
|
||||
return
|
||||
# reference implementation
|
||||
q_ref = q.permute(0, 2, 1, 3)
|
||||
k_ref = k.permute(0, 2, 1, 3)
|
||||
v_ref = v.permute(0, 2, 1, 3)
|
||||
torch.backends.cuda.enable_flash_sdp(enabled=True)
|
||||
ref_o = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
|
||||
).permute(0, 2, 1, 3)
|
||||
|
||||
torch.testing.assert_close(o.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
|
||||
|
||||
return avg_time_us # Return execution time in microseconds
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -1334,9 +1328,15 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference check"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
run_flash_attention_fwd(
|
||||
run(
|
||||
args.dtype,
|
||||
args.batch_size,
|
||||
args.seqlen_q,
|
||||
@ -1348,6 +1348,10 @@ if __name__ == "__main__":
|
||||
args.n_block_size,
|
||||
args.num_threads,
|
||||
args.is_causal,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
|
||||
print("PASS")
|
||||
|
||||
@ -634,16 +634,50 @@ class SGemm:
|
||||
return
|
||||
|
||||
|
||||
def main(
|
||||
def run(
|
||||
mnk: Tuple[int, int, int],
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
problem_shape: Tuple[int, int, int],
|
||||
static_shape: bool = False,
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
M, N, K = problem_shape
|
||||
"""Execute SIMT GEMM operation and benchmark performance.
|
||||
|
||||
:param mnk: GEMM problem size (M, N, K, L)
|
||||
:type mnk: Tuple[int, int, int, int]
|
||||
:param a_major: Memory layout of tensor A
|
||||
:type a_major: str
|
||||
:param b_major: Memory layout of tensor B
|
||||
:type b_major: str
|
||||
:param c_major: Memory layout of tensor C
|
||||
:type c_major: str
|
||||
:param static_shape: Whether to use static shape optimization, defaults to False
|
||||
:type static_shape: bool, optional
|
||||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 2
|
||||
:type warmup_iterations: int, optional
|
||||
:param iterations: Number of benchmark iterations to run, defaults to 100
|
||||
:type iterations: int, optional
|
||||
:param skip_ref_check: Skip validation against reference implementation, defaults to False
|
||||
:type skip_ref_check: bool, optional
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:return: Execution time of the GEMM kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
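Since the entry point is renamed from `main` to `run` and now takes the problem size as `mnk`, a direct Python invocation looks like the sketch below; the sizes and layout choices are illustrative.

.. code-block:: python

    avg_time_us = run(
        (8192, 8192, 8192),   # M, N, K
        a_major="k",
        b_major="k",
        c_major="n",
        static_shape=False,
        warmup_iterations=2,
        iterations=100,
    )
    print(f"Average kernel time: {avg_time_us / 1e3:.3f} ms")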
print(f"Running Ampere SIMT GEMM example:")
|
||||
print(f"mnk: {mnk}")
|
||||
print(f"A major: {a_major}, B major: {b_major}, C major: {c_major}")
|
||||
print(f"Static shape: {static_shape}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {use_cold_l2}")
|
||||
M, N, K = mnk
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype):
|
||||
@ -710,20 +744,6 @@ def main(
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
gemm,
|
||||
kernel_arguments=testing.JitArguments(
|
||||
a_tensor, b_tensor, c_tensor, current_stream
|
||||
),
|
||||
warmup_iterations=warmup_iterations,
|
||||
profiling_iterations=iterations,
|
||||
use_cuda_graphs=False,
|
||||
stream=current_stream,
|
||||
)
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
|
||||
if not skip_ref_check:
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
torch.cuda.synchronize()
|
||||
@ -732,6 +752,71 @@ def main(
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
# Create new tensors for each workspace to ensure cold L2 cache
|
||||
a_workspace = create_and_permute_tensor(M, K, a_major == "m", torch.float32)
|
||||
b_workspace = create_and_permute_tensor(N, K, b_major == "n", torch.float32)
|
||||
c_workspace = create_and_permute_tensor(M, N, c_major == "m", torch.float32)
|
||||
|
||||
if static_shape:
|
||||
a_tensor_workspace = (
|
||||
from_dlpack(a_workspace, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
divisibility=divisibility_a,
|
||||
)
|
||||
)
|
||||
else:
|
||||
a_tensor_workspace = from_dlpack(a_workspace, assumed_align=16)
|
||||
|
||||
b_tensor_workspace = (
|
||||
from_dlpack(b_workspace, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
divisibility=divisibility_b,
|
||||
)
|
||||
)
|
||||
|
||||
c_tensor_workspace = (
|
||||
from_dlpack(c_workspace, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
divisibility=divisibility_c,
|
||||
)
|
||||
)
|
||||
|
||||
return testing.JitArguments(
|
||||
a_tensor_workspace, b_tensor_workspace, c_tensor_workspace, current_stream
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a.numel() * a.element_size()
|
||||
+ b.numel() * b.element_size()
|
||||
+ c.numel() * c.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
|
||||
return avg_time_us # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -753,19 +838,27 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running SIMT GEMM example:")
|
||||
|
||||
torch.manual_seed(1024)
|
||||
|
||||
main(
|
||||
run(
|
||||
args.mnk,
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
args.mnk,
|
||||
args.static_shape,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -51,7 +51,7 @@ This GEMM kernel supports the following features:
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
- Threadblock rasterization to improve data re-use
- Supports multi-stage pipeline to overlap computation and memory access
- Implements shared memory buffering for epilogue to increase coalesed global memory access
- Implements shared memory buffering for epilogue to increase coalesced global memory access

This GEMM works as follows:
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
@ -214,7 +214,7 @@ class TensorOpGemm:
|
||||
atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
|
||||
)
|
||||
|
||||
# Creates a synchonous copy atom and thread layouts for the epilogue
|
||||
# Creates a synchronous copy atom and thread layouts for the epilogue
|
||||
c_copy_bits = 128
|
||||
atom_sync_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
@ -550,16 +550,8 @@ class TensorOpGemm:
|
||||
|
||||
# Creates the tiled copy so that it matches the thread-value layout
|
||||
# expected by the tiled mma
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy(
|
||||
atom_copy_s2r_A,
|
||||
layout_tv=tiled_mma.tv_layout_A_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy(
|
||||
atom_copy_s2r_B,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)
|
||||
|
||||
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
|
||||
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
|
||||
@ -836,8 +828,7 @@ class TensorOpGemm:
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
|
||||
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
|
||||
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
|
||||
|
||||
def raster_tile(self, i, j, f):
|
||||
new_i = i // f
|
||||
@ -845,20 +836,33 @@ class TensorOpGemm:
|
||||
return (new_i, new_j)
|
||||
|
||||
|
||||
def run_tensor_op_gemm(
|
||||
def run(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
problem_shape: Tuple[int, int, int, int],
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
atom_layout_mnk: Tuple[int, int, int],
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
M, N, K, L = problem_shape
|
||||
print(f"Running Ampere tensor core GEMM example:")
|
||||
print(f"mnkl: {mnkl}")
|
||||
print(
|
||||
f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
|
||||
)
|
||||
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
||||
print(f"Atoms layout: {atom_layout_mnk}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {use_cold_l2}")
|
||||
M, N, K, L = mnkl
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
|
||||
@ -866,23 +870,28 @@ def run_tensor_op_gemm(
|
||||
# else: (l, mode0, mode1) -> (mode0, mode1, l)
|
||||
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
|
||||
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
|
||||
|
||||
return (
|
||||
torch_tensor = (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(dtype=dtype)
|
||||
.to(dtype=cutlass_torch.dtype(dtype))
|
||||
.permute(permute_order)
|
||||
.cuda()
|
||||
)
|
||||
# assume input is 16B aligned
|
||||
cute_tensor = (
|
||||
from_dlpack(torch_tensor, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if not is_mode0_major else 0),
|
||||
stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
|
||||
divisibility=(128 // dtype.width),
|
||||
)
|
||||
)
|
||||
return cute_tensor, torch_tensor
|
||||
|
||||
a = create_and_permute_tensor(
|
||||
L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
b = create_and_permute_tensor(
|
||||
L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype))
|
||||
ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype))
|
||||
mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
|
||||
mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
|
||||
mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
|
||||
|
||||
tensor_op_gemm = TensorOpGemm(
|
||||
ab_dtype,
|
||||
@ -891,56 +900,49 @@ def run_tensor_op_gemm(
|
||||
atom_layout_mnk,
|
||||
)
|
||||
|
||||
# assume input is 16B aligned
|
||||
a_tensor = (
|
||||
from_dlpack(a, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
b_tensor = (
|
||||
from_dlpack(b, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
c_tensor = (
|
||||
from_dlpack(c, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0),
|
||||
divisibility=(128 // c_dtype.width),
|
||||
)
|
||||
)
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor)
|
||||
compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC)
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
if not skip_ref_check:
|
||||
ref = torch.einsum(
|
||||
"mkl,nkl->mnl",
|
||||
a_torch.to(dtype=torch.float32),
|
||||
b_torch.to(dtype=torch.float32),
|
||||
).to(cutlass_torch.dtype(c_dtype))
|
||||
compiled_gemm(mA, mB, mC)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
|
||||
b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
|
||||
c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
|
||||
return testing.JitArguments(a_workspace, b_workspace, c_workspace)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a_torch.numel() * a_torch.element_size()
|
||||
+ b_torch.numel() * b_torch.element_size()
|
||||
+ c_torch.numel() * c_torch.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
gemm,
|
||||
kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor),
|
||||
compiled_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
warmup_iterations=warmup_iterations,
|
||||
profiling_iterations=iterations,
|
||||
iterations=iterations,
|
||||
use_cuda_graphs=False,
|
||||
)
|
||||
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
|
||||
if not skip_ref_check:
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
return avg_time_us # Return execution time in microseconds
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -985,10 +987,15 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running Ampere tensor core GEMM example:")
|
||||
run_tensor_op_gemm(
|
||||
run(
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
@ -1000,5 +1007,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
(File diff suppressed because it is too large.)
@ -212,7 +212,7 @@ class DenseGemmKernel:
|
||||
|
||||
self.occupancy = 1
|
||||
self.threads_per_cta = 128
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1106,11 +1106,7 @@ class DenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1772,7 +1768,7 @@ def run_dense_gemm(
|
||||
ref_c = ref
|
||||
elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
|
||||
# m major: (l, n, m) -> (m, n, l)
|
||||
# k major: (l, m, n) -> (m, n, l)
|
||||
# n major: (l, m, n) -> (m, n, l)
|
||||
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
|
||||
shape = (l, m, n) if c_major == "n" else (l, n, m)
|
||||
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
|
||||
|
||||
@ -38,6 +38,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@ -226,7 +227,7 @@ class PersistentDenseGemmKernel:
|
||||
self.cta_sync_bar_id = 0
|
||||
self.epilog_sync_bar_id = 1
|
||||
self.tmem_ptr_sync_bar_id = 2
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1308,11 +1309,7 @@ class PersistentDenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1824,7 +1821,7 @@ class PersistentDenseGemmKernel:
|
||||
return can_implement
|
||||
|
||||
|
||||
def run_dense_gemm(
|
||||
def run(
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
@ -1832,17 +1829,58 @@ def run_dense_gemm(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
use_2cta_instrs: bool,
|
||||
use_tma_store: bool,
|
||||
tolerance: float,
|
||||
mma_tiler_mn: Tuple[int, int] = (256, 256),
|
||||
cluster_shape_mn: Tuple[int, int] = (2, 1),
|
||||
use_2cta_instrs: bool = True,
|
||||
use_tma_store: bool = True,
|
||||
tolerance: float = 1e-01,
|
||||
warmup_iterations: int = 0,
|
||||
iterations: int = 1,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
|
||||
"""Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
|
||||
|
||||
This function prepares input tensors, configures and launches the persistent GEMM kernel,
|
||||
optionally performs reference validation, and benchmarks the execution performance.
|
||||
|
||||
:param mnkl: Problem size (M, N, K, L)
|
||||
:type mnkl: Tuple[int, int, int, int]
|
||||
:param ab_dtype: Data type for input tensors A and B
|
||||
:type ab_dtype: Type[cutlass.Numeric]
|
||||
:param c_dtype: Data type for output tensor C
|
||||
:type c_dtype: Type[cutlass.Numeric]
|
||||
:param acc_dtype: Data type for accumulation during matrix multiplication
|
||||
:type acc_dtype: Type[cutlass.Numeric]
|
||||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||||
:type a_major/b_major/c_major: str
|
||||
:param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the
|
||||
default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type mma_tiler_mn: Tuple[int, int], optional
|
||||
:param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
|
||||
default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type cluster_shape_mn: Tuple[int, int], optional
|
||||
:param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
|
||||
will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type use_2cta_instrs: bool, optional
|
||||
:param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use
|
||||
the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type use_tma_store: bool, optional
|
||||
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
|
||||
:type tolerance: float, optional
|
||||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
|
||||
:type warmup_iterations: int, optional
|
||||
:param iterations: Number of benchmark iterations to run, defaults to 1
|
||||
:type iterations: int, optional
|
||||
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
||||
:type skip_ref_check: bool, optional
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:raises RuntimeError: If CUDA GPU is not available
|
||||
:raises ValueError: If the configuration is invalid or unsupported by the kernel
|
||||
:return: Execution time of the GEMM kernel
|
||||
:rtype: float
|
||||
"""
|
||||
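Because the MMA tiler, cluster shape, 2-CTA, TMA-store, and tolerance arguments now have defaults, a minimal call to the renamed `run` entry point only needs the problem size, dtypes, and majors. An illustrative sketch:

.. code-block:: python

    exec_time = run(
        (4096, 4096, 8192, 1),       # M, N, K, L
        ab_dtype=cutlass.Float16,
        c_dtype=cutlass.Float16,
        acc_dtype=cutlass.Float32,
        a_major="k",
        b_major="k",
        c_major="n",
    )  # uses mma_tiler_mn=(256, 256), cluster_shape_mn=(2, 1), etc.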
print(f"Running Blackwell Persistent Dense GEMM test with:")
|
||||
print(f"mnkl: {mnkl}")
|
||||
@ -1855,6 +1893,7 @@ def run_dense_gemm(
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
|
||||
# Unpack parameters
|
||||
m, n, k, l = mnkl
|
||||
@ -1931,15 +1970,15 @@ def run_dense_gemm(
|
||||
is_dynamic_layout=is_dynamic_layout,
|
||||
)
|
||||
|
||||
return f32_torch_tensor, cute_tensor, torch_tensor
|
||||
return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu
|
||||
|
||||
a_ref, a_tensor, a_torch = create_and_permute_tensor(
|
||||
a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor(
|
||||
l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
|
||||
)
|
||||
b_ref, b_tensor, b_torch = create_and_permute_tensor(
|
||||
b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor(
|
||||
l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
|
||||
)
|
||||
c_ref, c_tensor, c_torch = create_and_permute_tensor(
|
||||
c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor(
|
||||
l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True
|
||||
)
|
||||
|
||||
@ -1967,16 +2006,8 @@ def run_dense_gemm(
|
||||
gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream
|
||||
)
|
||||
|
||||
# Launch GPU kernel
|
||||
# Warm up
|
||||
for i in range(warmup_iterations):
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
# Execution
|
||||
for i in range(iterations):
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
|
||||
# Compute reference result
|
||||
if not skip_ref_check:
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
if ab_dtype in {
|
||||
cutlass.Int8,
|
||||
cutlass.Uint8,
|
||||
@ -2028,6 +2059,40 @@ def run_dense_gemm(
|
||||
rtol=1e-05,
|
||||
)
|
||||
|
||||
def generate_tensors():
|
||||
a_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
b_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
c_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a_torch_cpu.numel() * a_torch_cpu.element_size()
|
||||
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
|
||||
+ c_torch_cpu.numel() * c_torch_cpu.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -2090,6 +2155,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -2102,7 +2173,7 @@ if __name__ == "__main__":
|
||||
if len(args.cluster_shape_mn) != 2:
|
||||
parser.error("--cluster_shape_mn must contain exactly 2 values")
|
||||
|
||||
run_dense_gemm(
|
||||
run(
|
||||
args.mnkl,
|
||||
args.ab_dtype,
|
||||
args.c_dtype,
|
||||
@ -2118,5 +2189,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -223,7 +223,7 @@ class DenseGemmKernel:
|
||||
|
||||
self.occupancy = 1
|
||||
self.threads_per_cta = 128
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1063,11 +1063,7 @@ class DenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
|
||||
@ -43,6 +43,7 @@ import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.cute.typing import Int32, Int64, Float32, Boolean
|
||||
|
||||
@ -90,7 +91,7 @@ Constraints for this example:
* Number of heads in Q must be divisible by number of heads in K
* mma_tiler_mn must be 128,128
* Batch size must be the same for Q, K, and V tensors
* For causal masking, use --has_casual_mask (note: specify without =True/False)
* For causal masking, use --is_causal (note: specify without =True/False)
* For persistent scheduling, use --is_persistent (note: specify without =True/False)
"""

@ -2373,11 +2374,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
smem_copy_atom = sm100_utils.get_smem_store_op(
|
||||
self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy(
|
||||
smem_copy_atom,
|
||||
layout_tv=tiled_tmem_load.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_tmem_load.tiler_mn,
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
|
||||
|
||||
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
|
||||
tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
|
||||
@ -2619,7 +2616,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
return tile_sched_params, grid
|
||||
|
||||
|
||||
def run_fmha_and_verify(
|
||||
def run(
|
||||
q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
|
||||
k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
|
||||
in_dtype: Type[cutlass.Numeric],
|
||||
@ -2628,7 +2625,7 @@ def run_fmha_and_verify(
|
||||
pv_acc_dtype: Type[cutlass.Numeric],
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
is_persistent: bool,
|
||||
has_casual_mask: bool,
|
||||
is_causal: bool,
|
||||
scale_q: float,
|
||||
scale_k: float,
|
||||
scale_v: float,
|
||||
@ -2638,6 +2635,8 @@ def run_fmha_and_verify(
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results.
|
||||
|
||||
@ -2670,8 +2669,8 @@ def run_fmha_and_verify(
|
||||
:type mma_tiler_mn: Tuple[int, int]
|
||||
:param is_persistent: Whether to use persistent kernel optimization
|
||||
:type is_persistent: bool
|
||||
:param has_casual_mask: Whether to apply causal masking
|
||||
:type has_casual_mask: bool
|
||||
:param is_causal: Whether to apply causal masking
|
||||
:type is_causal: bool
|
||||
:param scale_q: Scaling factor for query tensor
|
||||
:type scale_q: float
|
||||
:param scale_k: Scaling factor for key tensor
|
||||
@ -2690,9 +2689,13 @@ def run_fmha_and_verify(
|
||||
:type iterations: int
|
||||
:param skip_ref_check: Skip validation against reference implementation
|
||||
:type skip_ref_check: bool
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache
|
||||
:type use_cold_l2: bool
|
||||
|
||||
:raises ValueError: If input shapes are incompatible or head dimension is unsupported
|
||||
:raises RuntimeError: If GPU is unavailable for computation
|
||||
:return: Execution time of the FMHA kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
print(f"Running Blackwell SM100 FMHA test with:")
|
||||
@ -2704,13 +2707,17 @@ def run_fmha_and_verify(
|
||||
print(f" pv_acc_dtype: {pv_acc_dtype}")
|
||||
print(f" mma_tiler_mn: {mma_tiler_mn}")
|
||||
print(f" is_persistent: {is_persistent}")
|
||||
print(f" has_casual_mask: {has_casual_mask}")
|
||||
print(f" is_causal: {is_causal}")
|
||||
print(f" scale_q: {scale_q}")
|
||||
print(f" scale_k: {scale_k}")
|
||||
print(f" scale_v: {scale_v}")
|
||||
print(f" inv_scale_o: {inv_scale_o}")
|
||||
print(f" scale_softmax: {scale_softmax}")
|
||||
print(f" tolerance: {tolerance}")
|
||||
print(f" warmup_iterations: {warmup_iterations}")
|
||||
print(f" iterations: {iterations}")
|
||||
print(f" skip_ref_check: {skip_ref_check}")
|
||||
print(f" use_cold_l2: {use_cold_l2}")
|
||||
|
||||
# Unpack parameters
|
||||
b, s_q, h_q, d = q_shape
|
||||
@ -2882,7 +2889,7 @@ def run_fmha_and_verify(
|
||||
mma_tiler = (*mma_tiler_mn, d)
|
||||
|
||||
mask_type = MaskType.NO_MASK
|
||||
if has_casual_mask:
|
||||
if is_causal:
|
||||
mask_type = MaskType.CAUSAL_MASK
|
||||
else:
|
||||
if isinstance(s_k, tuple):
|
||||
@ -2942,41 +2949,7 @@ def run_fmha_and_verify(
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Execute kernel
|
||||
for _ in range(iterations):
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_torch_fmha(
|
||||
q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False
|
||||
):
|
||||
def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False):
|
||||
h_q = q.shape[2]
|
||||
h_k = k.shape[2]
|
||||
|
||||
@ -3005,7 +2978,7 @@ def run_fmha_and_verify(
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# For the situation that torch has not supported, we need to handle it manually
|
||||
situation1 = has_casual_mask and (q.is_nested or k.is_nested)
|
||||
situation1 = is_causal and (q.is_nested or k.is_nested)
|
||||
situation2 = (q.is_nested and not k.is_nested) or (
|
||||
not q.is_nested and k.is_nested
|
||||
)
|
||||
@ -3025,8 +2998,9 @@ def run_fmha_and_verify(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=has_casual_mask,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
ref_i = ref_i.transpose(0, 1) * scale_output
|
||||
ref_list.append(ref_i)
|
||||
if q.is_nested:
|
||||
ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged)
|
||||
@ -3040,15 +3014,28 @@ def run_fmha_and_verify(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=has_casual_mask,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
return ref
|
||||
|
||||
if not skip_ref_check:
|
||||
# Execute kernel once for reference checking
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
print("Verifying results...")
|
||||
o_ref = run_torch_fmha(
|
||||
q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask
|
||||
q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal
|
||||
)
|
||||
|
||||
if o_ref.is_nested:
|
||||
@ -3095,6 +3082,76 @@ def run_fmha_and_verify(
|
||||
torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
_, q_tensor_workspace, _ = create_and_pad_tensor(
|
||||
qo_shape,
|
||||
qo_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_q_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, k_tensor_workspace, _ = create_and_pad_tensor(
|
||||
kv_shape,
|
||||
kv_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_k_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, v_tensor_workspace, _ = create_and_pad_tensor(
|
||||
kv_shape,
|
||||
kv_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_k_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, o_tensor_workspace, _ = create_and_pad_tensor(
|
||||
qo_shape,
|
||||
qo_padding,
|
||||
out_dtype,
|
||||
s_cumsum=cum_seqlen_q_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
return testing.JitArguments(
|
||||
q_tensor_workspace.iterator,
|
||||
k_tensor_workspace.iterator,
|
||||
v_tensor_workspace.iterator,
|
||||
o_tensor_workspace.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch
|
||||
k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch
|
||||
v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch
|
||||
o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch
|
||||
one_workspace_bytes = (
|
||||
q_torch_effective.numel() * q_torch_effective.element_size()
|
||||
+ k_torch_effective.numel() * k_torch_effective.element_size()
|
||||
+ v_torch_effective.numel() * v_torch_effective.element_size()
|
||||
+ o_torch_effective.numel() * o_torch_effective.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_fmha,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str):
|
||||
@ -3185,7 +3242,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--has_casual_mask",
|
||||
"--is_causal",
|
||||
action="store_true",
|
||||
help="Whether to use casual mask",
|
||||
)
|
||||
@ -3263,6 +3320,13 @@ if __name__ == "__main__":
|
||||
help="Skip reference check",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if len(args.q_shape) != 4:
|
||||
@ -3279,7 +3343,7 @@ if __name__ == "__main__":
|
||||
|
||||
torch.manual_seed(1111)
|
||||
|
||||
run_fmha_and_verify(
|
||||
run(
|
||||
args.q_shape,
|
||||
args.k_shape,
|
||||
args.in_dtype,
|
||||
@ -3288,7 +3352,7 @@ if __name__ == "__main__":
|
||||
args.pv_acc_dtype,
|
||||
args.mma_tiler_mn,
|
||||
args.is_persistent,
|
||||
args.has_casual_mask,
|
||||
args.is_causal,
|
||||
args.scale_q,
|
||||
args.scale_k,
|
||||
args.scale_v,
|
||||
@ -3298,6 +3362,7 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
|
||||
print("PASS")
|
||||
|
||||
@ -36,6 +36,7 @@ import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
@ -157,7 +158,7 @@ class GroupedGemmKernel:
|
||||
self.tmem_ptr_sync_bar_id = 2
|
||||
# Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion
|
||||
self.tensormap_ab_init_bar_id = 4
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
self.num_tma_load_bytes = 0
|
||||
|
||||
def _setup_attributes(self):
|
||||
@ -951,7 +952,7 @@ class GroupedGemmKernel:
|
||||
# Specialized MMA warp
|
||||
#
|
||||
if warp_idx == self.mma_warp_id:
|
||||
# initilize tensormap A, B for TMA warp
|
||||
# initialize tensormap A, B for TMA warp
|
||||
if cutlass.const_expr(self.delegate_tensormap_ab_init):
|
||||
tensormap_manager.init_tensormap_from_atom(
|
||||
tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id
|
||||
@ -1540,11 +1541,7 @@ class GroupedGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1815,7 +1812,136 @@ class GroupedGemmKernel:
tensor_memory_management_bytes = 12


def run_grouped_gemm(
# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
mode0: int,
mode1: int,
is_mode0_major: bool,
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
torch_tensor_cpu: torch.Tensor = None,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
"""Create a GPU tensor from scratch or based on an existing CPU tensor.

:param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one.
:type torch_tensor_cpu: torch.Tensor, optional
"""
if torch_tensor_cpu is None:
# Create new CPU tensor
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)

# Create GPU tensor from CPU tensor (new or existing)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
torch_tensor_cpu,
torch_tensor.stride()[:-1],
)


def create_tensors_for_all_groups(
problem_sizes_mnkl: List[tuple[int, int, int, int]],
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
torch_fp32_tensors_abc: List[List[torch.Tensor]] = None,
) -> tuple[
List[List[int]],
List[List[torch.Tensor]],
List[tuple],
List[List[tuple]],
List[List[torch.Tensor]],
]:
if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len(
problem_sizes_mnkl
):
raise ValueError("torch_fp32_tensors_abc must have one entry per group")

# Initialize lists to store tensors for all groups
new_torch_fp32_tensors_abc = (
[] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc
)
torch_tensors_abc = []
cute_tensors_abc = []
strides_abc = []
ptrs_abc = []

# Iterate through all groups and create tensors for each group
for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl):
# Get existing CPU tensors if available, otherwise None
existing_cpu_a = (
torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None
)
existing_cpu_b = (
torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None
)
existing_cpu_c = (
torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None
)

# Create tensors (reusing CPU tensors if provided)
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
tensor_fp32_a,
stride_mk_a,
) = create_tensor_and_stride(
l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a
)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
tensor_fp32_b,
stride_nk_b,
) = create_tensor_and_stride(
l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b
)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
tensor_fp32_c,
stride_mn_c,
) = create_tensor_and_stride(
l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c
)

# Only append to new_torch_fp32_tensors_abc if we created new CPU tensors
if torch_fp32_tensors_abc is None:
new_torch_fp32_tensors_abc.append(
[tensor_fp32_a, tensor_fp32_b, tensor_fp32_c]
)

ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
cute_tensors_abc.append(
(
cute_tensor_a,
cute_tensor_b,
cute_tensor_c,
)
)

return (
ptrs_abc,
torch_tensors_abc,
cute_tensors_abc,
strides_abc,
new_torch_fp32_tensors_abc,
)

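A hedged usage sketch of the helpers above: the first call creates fresh fp32 CPU tensors plus device copies for every group; passing the returned CPU tensors back in builds a second, value-identical device workspace in new allocations, which is the reuse pattern `generate_tensors` relies on below for cold-L2 benchmarking. The two-group problem sizes and dtypes here are illustrative only.

problem_sizes = [(256, 512, 128, 1), (384, 256, 128, 1)]  # illustrative sizes

# First call: new CPU tensors and device-side A/B/C for every group.
ptrs, torch_abc, cute_abc, strides, cpu_fp32_abc = create_tensors_for_all_groups(
    problem_sizes, cutlass.Float16, cutlass.Float16, "k", "k", "n"
)

# Second call: reuse the CPU tensors so the new device workspace holds the
# same values but lives in different GPU allocations.
ptrs2, torch_abc2, cute_abc2, strides2, _ = create_tensors_for_all_groups(
    problem_sizes, cutlass.Float16, cutlass.Float16, "k", "k", "n", cpu_fp32_abc
)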
def run(
num_groups: int,
problem_sizes_mnkl: tuple[int, int, int, int],
ab_dtype: Type[cutlass.Numeric],
@ -1832,8 +1958,16 @@ def run_grouped_gemm(
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""Run grouped GEMM example with specified configurations."""
"""Run grouped GEMM example with specified configurations.

:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""
print(f"Running Blackwell Grouped GEMM test with:")
print(f"{num_groups} groups")
for i, (m, n, k, l) in enumerate(problem_sizes_mnkl):
@ -1847,6 +1981,7 @@ def run_grouped_gemm(
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")

# Skip unsupported types
if ab_dtype not in {
@ -1902,66 +2037,22 @@ def run_grouped_gemm(
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run this example!")

# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
mode0: int,
mode1: int,
is_mode0_major: bool,
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
torch_tensor_cpu,
torch_tensor.stride()[:-1],
)
# Create tensors for all groups using the new function
(
ptrs_abc,
torch_tensors_abc,
cute_tensors_abc,
strides_abc,
torch_fp32_tensors_abc,
) = create_tensors_for_all_groups(
problem_sizes_mnkl,
ab_dtype,
c_dtype,
a_major,
b_major,
c_major,
)

# iterate all groups and create tensors for each group
torch_fp32_tensors_abc = []
torch_tensors_abc = []
cute_tensors_abc = []
strides_abc = []
ptrs_abc = []
for _, (m, n, k, l) in enumerate(problem_sizes_mnkl):
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
tensor_fp32_a,
stride_mk_a,
) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
tensor_fp32_b,
stride_nk_b,
) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
tensor_fp32_c,
stride_mn_c,
) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c])
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
cute_tensors_abc.append(
(
cute_tensor_a,
cute_tensor_b,
cute_tensor_c,
)
)
# Choose A, B, C with the smallest size to create initial tensormaps
key_size_a = lambda item: item[1][0] * item[1][2]
key_size_b = lambda item: item[1][1] * item[1][2]
@ -2078,36 +2169,19 @@ def run_grouped_gemm(
current_stream,
)

# Launch GPU kernel
# Warm up
for _ in range(warmup_iterations):
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
# Execution
for i in range(iterations):
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)

torch.cuda.synchronize()

# Compute reference result
if not skip_ref_check:
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)

# Compute reference result
for i, (a, b, c) in enumerate(torch_tensors_abc):
ref = torch.einsum(
"mkl,nkl->mnl",
@ -2122,6 +2196,102 @@ def run_grouped_gemm(
rtol=1e-05,
)

def generate_tensors():
# Reuse existing CPU tensors and create new GPU tensors from them
(
ptrs_abc_workspace,
torch_tensors_abc_workspace,
cute_tensors_abc_workspace,
strides_abc_workspace,
_,
) = create_tensors_for_all_groups(
problem_sizes_mnkl,
ab_dtype,
c_dtype,
a_major,
b_major,
c_major,
torch_fp32_tensors_abc,
)

initial_cute_tensors_abc_workspace = [
cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k)
cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k)
cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n)
]

# Create new tensors for this workspace
tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like(
torch.tensor(strides_abc_workspace, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)

tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_abc_workspace, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
assumed_align=16,
)

tensormap_workspace, _ = cutlass_torch.cute_tensor_like(
torch.empty(tensormap_shape, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
)

return testing.JitArguments(
initial_cute_tensors_abc_workspace[0],
initial_cute_tensors_abc_workspace[1],
initial_cute_tensors_abc_workspace[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc_workspace,
tensor_of_ptrs_abc_workspace,
tensormap_workspace,
current_stream,
)

workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
sum(
[
sum(
[
torch_tensor.numel() * torch_tensor.element_size()
for torch_tensor in group_tensors
]
)
for group_tensors in torch_tensors_abc
]
)
+
# Add size of strides tensor
tensor_of_strides_abc_torch.numel()
* tensor_of_strides_abc_torch.element_size()
+
# Add size of ptrs tensor
tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size()
+
# Add size of tensormap tensor
tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)

exec_time = testing.benchmark(
compiled_grouped_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)

return exec_time # Return execution time in microseconds

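The `run` function above returns the measured kernel time in microseconds. A small hedged sketch (not part of the example) of turning that into an aggregate throughput figure, assuming `exec_time` is the value returned above and `problem_sizes_mnkl` is the same list that was passed in:

# 2*M*N*K*L floating-point operations per group, summed over all groups.
total_flops = sum(2 * m * n * k * l for (m, n, k, l) in problem_sizes_mnkl)
tflops = total_flops / (exec_time * 1e-6) / 1e12
print(f"Grouped GEMM: {exec_time:.1f} us, {tflops:.2f} TFLOP/s")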
if __name__ == "__main__":

@ -2218,6 +2388,12 @@ if __name__ == "__main__":
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)

args = parser.parse_args()

@ -2248,7 +2424,7 @@ if __name__ == "__main__":

torch.manual_seed(2025)

run_grouped_gemm(
run(
args.num_groups,
args.problem_sizes_mnkl,
args.ab_dtype,
@ -2265,5 +2441,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

@ -29,13 +29,14 @@

import argparse
from typing import List, Type, Tuple, Optional
from cuda import cuda
import cuda.bindings.driver as cuda

import torch
import torch.nn.functional as F

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
@ -43,13 +44,16 @@ import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack

from .mamba2_ssd_reference import (
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent))
from mamba2_ssd_reference import (
ssd_reference_fp32_all,
ssd_reference_lowprecision_intermediates,
analyze_relative_diffs,
)

from .mamba2_ssd_tile_scheduler import (
from mamba2_ssd_tile_scheduler import (
Mamba2SSDTileSchedulerParams,
Mamba2SSDTileScheduler,
)
@ -122,7 +126,7 @@ class SSDKernel:
*self.epilog_warp_id,
)
)
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")

# Named barriers
self.pre_inter_sync_bar_id = 1
@ -1522,7 +1526,10 @@ class SSDKernel:
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b(
local_tidx, smem_bt_internal_, tiled_s2r_b, tBrB_s2r
local_tidx,
smem_bt_internal_,
tiled_s2r_b,
tBrB_s2r,
)

# (MMA, MMA_M, MMA_K, INPUT_STAGE)
@ -3053,7 +3060,7 @@ class SSDKernel:

# SegSum
# fadd2 + fsel + fmul2/mufu + fmul2
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
(
tCompute[subtile_idx],
tCompute[subtile_idx + 1],
@ -3061,11 +3068,11 @@ class SSDKernel:
(tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]),
(-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]),
)
for subtile_idx in range(cute.size(tTR_rQ)):
for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True):
m, n = tCoord[subtile_idx]
if m < n:
tCompute[subtile_idx] = cutlass.Float32(-float("inf"))
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
# TODO: use math.exp directly
(
tCompute[subtile_idx],
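The SegSum loops above switch from Python `range` to `cutlass.range(..., unroll_full=True)` so the DSL fully unrolls the statically sized subtile loop instead of tracing it as a dynamic loop. A hedged, stand-alone sketch of the same pattern on a generic register fragment `frag` (a hypothetical name):

# Process a statically sized fragment two elements at a time, fully unrolled.
for i in cutlass.range(0, cute.size(frag), 2, unroll_full=True):
    frag[i] = frag[i] * 2.0
    frag[i + 1] = frag[i + 1] * 2.0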
@ -3130,11 +3137,7 @@ class SSDKernel:
dtype,
num_bits_per_copy=128,
)
tiled_r2s_b = cute.make_tiled_copy(
copy_atom_r2s_b,
layout_tv=tiled_s2r_b.layout_tv_tiled,
tiler_mn=tiled_s2r_b.tiler_mn,
)
tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b)
thr_r2s_b = tiled_r2s_b.get_slice(local_tidx)

# Partition shared tensor for smem store Bt
@ -3333,17 +3336,24 @@ class SSDKernel:
)


def run_ssd(
def run(
gbehcdln: Tuple[int, int, int, int, int, int, int, int],
io_dtype: Type[cutlass.Numeric],
cumsum_delta_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
has_d: bool,
d_has_hdim: bool,
fuse_scale_d: str,
tolerance: float,
print_rtol_stats: bool,
ref_lower_precision: bool,
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
has_d = fuse_scale_d != "none"
d_has_hdim = fuse_scale_d == "vector"

print(f"Running B100 Mamba2 SSD with:")
print(f"GBEHCDLN: {gbehcdln}")
print(
@ -3353,6 +3363,10 @@ def run_ssd(
f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}"
)
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")

# Unpack parameters
G, B, E, H, C, D, L, N = gbehcdln
@ -3515,39 +3529,146 @@ def run_ssd(
stream,
)

# Launch compiled ssd kernel
compiled_ssd(
x_tensor,
cumsum_delta_tensor,
delta_tensor,
b_tensor,
c_tensor,
y_tensor,
fstate_tensor,
d_tensor,
stream,
# Launch compiled ssd kernel for reference check
if not skip_ref_check:
compiled_ssd(
x_tensor,
cumsum_delta_tensor,
delta_tensor,
b_tensor,
c_tensor,
y_tensor,
fstate_tensor,
d_tensor,
stream,
)

# Reference check
if print_rtol_stats:
print("\nY's Relative diffs:")
analyze_relative_diffs(
y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype))
)
print("\nFstate's Relative diffs:")
analyze_relative_diffs(
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
)
torch.testing.assert_close(
y_torch.cpu(),
y_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-02,
)
torch.testing.assert_close(
fstate_torch.cpu(),
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-05,
)

def generate_tensors():
# Reuse existing CPU reference tensors and create new GPU tensors from them
_, x_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, C, L],
[2, 4, 3, 1, 0],
io_dtype,
ref_tensor=x_ref,
dynamic_modes=[2, 3, 4],
)
_, cumsum_delta_tensor_new, _ = create_and_permute_tensor(
[B, EH, C, L],
[3, 2, 1, 0],
cumsum_delta_dtype,
ref_tensor=cumsum_delta_ref,
dynamic_modes=[1, 2, 3],
)
_, delta_tensor_new, _ = create_and_permute_tensor(
[B, EH, C, L],
[3, 2, 1, 0],
io_dtype,
ref_tensor=delta_ref,
dynamic_modes=[1, 2, 3],
)
_, b_tensor_new, _ = create_and_permute_tensor(
[B, G, N, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=b_ref,
dynamic_modes=[2, 3, 4],
)
_, c_tensor_new, _ = create_and_permute_tensor(
[B, G, N, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=c_ref,
dynamic_modes=[2, 3, 4],
)
_, y_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=y_ref,
dynamic_modes=[2, 3, 4],
)
_, fstate_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, N],
[2, 3, 1, 0],
io_dtype,
ref_tensor=fstate_ref,
dynamic_modes=[2, 3],
)

if has_d:
_, d_tensor_new, _ = create_and_permute_tensor(
[EH, D if d_has_hdim else 1],
[1, 0],
io_dtype,
ref_tensor=d_ref,
dynamic_modes=[1],
)
else:
d_tensor_new = d_tensor

return testing.JitArguments(
x_tensor_new,
cumsum_delta_tensor_new,
delta_tensor_new,
b_tensor_new,
c_tensor_new,
y_tensor_new,
fstate_tensor_new,
d_tensor_new,
stream,
)

workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
x_torch.numel() * x_torch.element_size()
+ cumsum_delta_torch.numel() * cumsum_delta_torch.element_size()
+ delta_torch.numel() * delta_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
+ y_torch.numel() * y_torch.element_size()
+ fstate_torch.numel() * fstate_torch.element_size()
)
if has_d:
one_workspace_bytes += d_torch.numel() * d_torch.element_size()

workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)

exec_time = testing.benchmark(
compiled_ssd,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)

# Reference check
if print_rtol_stats:
print("\nY's Relative diffs:")
analyze_relative_diffs(y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype)))
print("\nFstate's Relative diffs:")
analyze_relative_diffs(
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
)
torch.testing.assert_close(
y_torch.cpu(),
y_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-02,
)
torch.testing.assert_close(
fstate_torch.cpu(),
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-05,
)
return exec_time # Return execution time in microseconds


if __name__ == "__main__":
@ -3586,15 +3707,53 @@ if __name__ == "__main__":
)
parser.add_argument(
"--ref_lower_precision",
type=bool,
action="store_true",
default=True,
help="Use lower precision for reference check",
)
parser.add_argument(
"--no-ref_lower_precision",
action="store_false",
dest="ref_lower_precision",
default=False,
help="Disable lower precision for reference check",
)
parser.add_argument(
"--tolerance", type=float, default=5e-02, help="Tolerance for validation"
)
parser.add_argument(
"--print_rtol_stats", type=bool, default=True, help="Print rtol stats"
"--print_rtol_stats",
action="store_true",
default=True,
help="Enable print rtol stats",
)
parser.add_argument(
"--no-print_rtol_stats",
action="store_false",
dest="print_rtol_stats",
default=False,
help="Disable print rtol stats",
)
parser.add_argument(
"--warmup_iterations",
type=int,
default=0,
help="Number of warmup iterations",
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)

args = parser.parse_args()
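The boolean options above drop `type=bool` (for which argparse turns any non-empty string, including "False", into `True`) in favour of paired `store_true`/`store_false` flags that share one destination. A self-contained sketch of the pattern:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--print_rtol_stats", action="store_true", default=True,
                    help="Enable print rtol stats")
parser.add_argument("--no-print_rtol_stats", action="store_false",
                    dest="print_rtol_stats", help="Disable print rtol stats")

print(parser.parse_args([]).print_rtol_stats)                         # True
print(parser.parse_args(["--no-print_rtol_stats"]).print_rtol_stats)  # False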
@ -3602,18 +3761,18 @@ if __name__ == "__main__":
if len(args.gbehcdln) != 8:
parser.error("--gbehcdln must contain exactly 8 values")

has_d = args.fuse_scale_d != "none"
d_has_hdim = args.fuse_scale_d == "vector"

run_ssd(
run(
args.gbehcdln,
args.io_dtype,
args.cumsum_delta_dtype,
args.acc_dtype,
has_d,
d_has_hdim,
args.fuse_scale_d,
args.tolerance,
args.print_rtol_stats,
args.ref_lower_precision,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

@ -35,6 +35,7 @@ import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
@ -166,6 +167,24 @@ def parse_arguments() -> argparse.Namespace:
parser.add_argument(
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
)
parser.add_argument(
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations to run the kernel",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)

args = parser.parse_args()

@ -264,7 +283,7 @@ class HopperWgmmaGemmKernel:
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
self.num_threads_per_warp_group = 128
self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group
self.smem_capacity = sm90_utils.SMEM_CAPACITY["sm90"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")

self.ab_stage = None
self.epi_stage = None
@ -1309,7 +1328,7 @@ class HopperWgmmaGemmKernel:
}:
is_valid = False
# tested acc_dtype
if acc_dtype != cutlass.Float32:
if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
is_valid = False
# tested c_dtype
if c_dtype not in {
@ -1335,7 +1354,7 @@ class HopperWgmmaGemmKernel:
return is_valid


def run_dense_gemm(
def run(
mnkl: Tuple[int, int, int, int],
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
@ -1347,9 +1366,43 @@ def run_dense_gemm(
tile_shape_mnk: Tuple[int, int, int],
cluster_shape_mn: Tuple[int, int],
tolerance: float,
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""
Prepare A/B/C tensors, launch GPU kernel, and reference checking.

:param mnkl: Problem size (M, N, K, L)
:type mnkl: Tuple[int, int, int, int]
:param a_dtype: Data type for input tensor A
:type a_dtype: Type[cutlass.Numeric]
:param b_dtype: Data type for input tensor B
:type b_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C
:type c_dtype: Type[cutlass.Numeric]
:param acc_dtype: Data type for accumulation during matrix multiplication
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param tile_shape_mnk: CTA tile shape (M, N, K)
:type tile_shape_mnk: Tuple[int, int, int]
:param cluster_shape_mn: Cluster shape (M, N)
:type cluster_shape_mn: Tuple[int, int]
:param tolerance: Tolerance value for reference validation comparison
:type tolerance: float
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 1
:type iterations: int, optional
:param skip_ref_check: Whether to skip reference result validation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""

print(f"Running Hopper Dense GEMM with:")
@ -1360,6 +1413,10 @@ def run_dense_gemm(
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")

# Unpack parameters
m, n, k, l = mnkl
@ -1437,46 +1494,76 @@ def run_dense_gemm(
stream = cuda.CUstream(torch_stream.cuda_stream)
# compile gemm kernel
compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
# execution
compiled_gemm(mA, mB, mC, stream)

torch.cuda.synchronize()
if not skip_ref_check:
# execution
compiled_gemm(mA, mB, mC, stream)

# Ref check
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
torch.cuda.synchronize()

if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
# m major: (l, n, m) -> (m, n, l)
# k major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch.uint8,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.SKIP,
).cuda()
# Create dtype cute tensor (gpu)
ref_c_tensor = from_dlpack(
f8_torch_tensor, assumed_align=16
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
ref_c_tensor.element_type = c_dtype
ref_c_tensor = cutlass_torch.convert_cute_tensor(
ref,
ref_c_tensor,
c_dtype,
is_dynamic_layout=True,
# Ref check
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()

if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
# m major: (l, n, m) -> (m, n, l)
# n major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch.uint8,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.SKIP,
).cuda()
# Create dtype cute tensor (gpu)
ref_c_tensor = from_dlpack(
f8_torch_tensor, assumed_align=16
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
ref_c_tensor.element_type = c_dtype
ref_c_tensor = cutlass_torch.convert_cute_tensor(
ref,
ref_c_tensor,
c_dtype,
is_dynamic_layout=True,
)
ref_c = f8_torch_tensor.cpu()
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))

torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)

def generate_tensors():
_, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
_, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
_, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)

workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch.numel() * a_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
ref_c = f8_torch_tensor.cpu()
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))

torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
exec_time = testing.benchmark(
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)

return exec_time # Return execution time in microseconds


if __name__ == "__main__":
args = parse_arguments()
run_dense_gemm(
run(
args.mnkl,
args.a_dtype,
args.b_dtype,
@ -1488,5 +1575,9 @@ if __name__ == "__main__":
args.tile_shape_mnk,
args.cluster_shape_mn,
args.tolerance,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

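Besides the CLI path through `parse_arguments()`, `run` can also be driven directly. A hedged sketch using the keyword names documented in the docstring above; the problem size, dtypes, and tile/cluster shapes are illustrative values, not a recommended configuration:

# Programmatic invocation of the Hopper dense GEMM example (sketch).
exec_time_us = run(
    mnkl=(4096, 4096, 4096, 1),
    a_dtype=cutlass.Float16,
    b_dtype=cutlass.Float16,
    c_dtype=cutlass.Float16,
    acc_dtype=cutlass.Float32,
    a_major="k",
    b_major="k",
    c_major="n",
    tile_shape_mnk=(128, 256, 64),
    cluster_shape_mn=(1, 1),
    tolerance=1e-01,
    warmup_iterations=2,
    iterations=10,
    skip_ref_check=False,
    use_cold_l2=True,
)
print(f"Average kernel time: {exec_time_us:.1f} us")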
@ -399,6 +399,70 @@
"\n",
"tensor_print_example3()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To print the tensor in device memory, you can use `cute.print_tensor` within CuTe JIT kernels."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def print_tensor_gpu(src: cute.Tensor):\n",
" print(src)\n",
" cute.print_tensor(src)\n",
"\n",
"@cute.jit\n",
"def print_tensor_host(src: cute.Tensor):\n",
" print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor<ptr<f32, gmem> o (4,3):(3,1)>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n",
" [[-0.690547, -0.274619, -1.659539, ],\n",
" [-1.843524, -1.648711, 1.163431, ],\n",
" [-0.716668, -1.900705, 0.592515, ],\n",
" [ 0.711333, -0.552422, 0.860237, ]])\n"
]
}
],
"source": [
"import torch\n",
"def tensor_print_example4():\n",
" a = torch.randn(4, 3, device=\"cuda\")\n",
" cutlass.cuda.initialize_cuda_context()\n",
" print_tensor_host(from_dlpack(a))\n",
"\n",
"tensor_print_example4()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently, `cute.print_tensor` only supports tensor with integer data types and `Float16`/`Float32`/`Float64` floating point data types. We will support more data types in the future."
]
}
],
"metadata": {

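Given the data-type restriction noted in the markdown cell above, one workaround (a sketch, not part of the notebook) is to cast an unsupported input such as bfloat16 to `float32` on the torch side before handing it to the `print_tensor_host` helper defined earlier:

a_bf16 = torch.randn(4, 3, device="cuda", dtype=torch.bfloat16)
# Cast to a supported dtype first; values are converted, the (4,3) layout is kept.
print_tensor_host(from_dlpack(a_bf16.float()))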
@ -256,16 +256,6 @@
" cute.printf(\"a[2,3] = {}\", a[2,3])\n",
" cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n",
"\n",
"@cute.kernel\n",
"def print_tensor_gpu(ptr: cute.Pointer):\n",
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
" tensor = cute.make_tensor(ptr, layout)\n",
"\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
"\n",
" if tidx == 0:\n",
" cute.print_tensor(tensor)\n",
"\n",
"\n",
"# Create a tensor with sequential data using torch\n",
"data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n",

@ -363,7 +363,7 @@
"| | \"few_channels\" | optimized for small `C` and requires `C % alignment_input == 0`|\n",
"| | \"fixed_channels\" | optimized for small `C` and requires `C == alignment_input` |\n",
"|Dgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n",
"| | \"optimized\" | Optimzed for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n",
"| | \"optimized\" | Optimized for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n",
"|Wgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n",
"| | \"optimized\" | Optimized for and require `K % alignment_grad_output == 0`, and `C % alignment_input == 0`|\n",
"\n",
