v4.1 release update v2. (#2481)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@ -212,7 +212,7 @@ class DenseGemmKernel:
|
||||
|
||||
self.occupancy = 1
|
||||
self.threads_per_cta = 128
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1106,11 +1106,7 @@ class DenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1772,7 +1768,7 @@ def run_dense_gemm(
|
||||
ref_c = ref
|
||||
elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
|
||||
# m major: (l, n, m) -> (m, n, l)
|
||||
# k major: (l, m, n) -> (m, n, l)
|
||||
# n major: (l, m, n) -> (m, n, l)
|
||||
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
|
||||
shape = (l, m, n) if c_major == "n" else (l, n, m)
|
||||
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
|
||||
|
||||
@ -38,6 +38,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@ -226,7 +227,7 @@ class PersistentDenseGemmKernel:
|
||||
self.cta_sync_bar_id = 0
|
||||
self.epilog_sync_bar_id = 1
|
||||
self.tmem_ptr_sync_bar_id = 2
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1308,11 +1309,7 @@ class PersistentDenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1824,7 +1821,7 @@ class PersistentDenseGemmKernel:
|
||||
return can_implement
|
||||
|
||||
|
||||
def run_dense_gemm(
|
||||
def run(
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
@ -1832,17 +1829,58 @@ def run_dense_gemm(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
use_2cta_instrs: bool,
|
||||
use_tma_store: bool,
|
||||
tolerance: float,
|
||||
mma_tiler_mn: Tuple[int, int] = (256, 256),
|
||||
cluster_shape_mn: Tuple[int, int] = (2, 1),
|
||||
use_2cta_instrs: bool = True,
|
||||
use_tma_store: bool = True,
|
||||
tolerance: float = 1e-01,
|
||||
warmup_iterations: int = 0,
|
||||
iterations: int = 1,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
|
||||
"""Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
|
||||
|
||||
This function prepares input tensors, configures and launches the persistent GEMM kernel,
|
||||
optionally performs reference validation, and benchmarks the execution performance.
|
||||
|
||||
:param mnkl: Problem size (M, N, K, L)
|
||||
:type mnkl: Tuple[int, int, int, int]
|
||||
:param ab_dtype: Data type for input tensors A and B
|
||||
:type ab_dtype: Type[cutlass.Numeric]
|
||||
:param c_dtype: Data type for output tensor C
|
||||
:type c_dtype: Type[cutlass.Numeric]
|
||||
:param acc_dtype: Data type for accumulation during matrix multiplication
|
||||
:type acc_dtype: Type[cutlass.Numeric]
|
||||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||||
:type a_major/b_major/c_major: str
|
||||
:param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the
|
||||
default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type mma_tiler_mn: Tuple[int, int], optional
|
||||
:param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
|
||||
default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type cluster_shape_mn: Tuple[int, int], optional
|
||||
:param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
|
||||
will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type use_2cta_instrs: bool, optional
|
||||
:param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use
|
||||
the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type use_tma_store: bool, optional
|
||||
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
|
||||
:type tolerance: float, optional
|
||||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
|
||||
:type warmup_iterations: int, optional
|
||||
:param iterations: Number of benchmark iterations to run, defaults to 1
|
||||
:type iterations: int, optional
|
||||
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
||||
:type skip_ref_check: bool, optional
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:raises RuntimeError: If CUDA GPU is not available
|
||||
:raises ValueError: If the configuration is invalid or unsupported by the kernel
|
||||
:return: Execution time of the GEMM kernel
|
||||
:rtype: float
|
||||
"""
|
||||
print(f"Running Blackwell Persistent Dense GEMM test with:")
|
||||
print(f"mnkl: {mnkl}")
|
||||
@ -1855,6 +1893,7 @@ def run_dense_gemm(
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
|
||||
# Unpack parameters
|
||||
m, n, k, l = mnkl
|
||||
@ -1931,15 +1970,15 @@ def run_dense_gemm(
|
||||
is_dynamic_layout=is_dynamic_layout,
|
||||
)
|
||||
|
||||
return f32_torch_tensor, cute_tensor, torch_tensor
|
||||
return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu
|
||||
|
||||
a_ref, a_tensor, a_torch = create_and_permute_tensor(
|
||||
a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor(
|
||||
l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
|
||||
)
|
||||
b_ref, b_tensor, b_torch = create_and_permute_tensor(
|
||||
b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor(
|
||||
l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
|
||||
)
|
||||
c_ref, c_tensor, c_torch = create_and_permute_tensor(
|
||||
c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor(
|
||||
l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True
|
||||
)
|
||||
|
||||
@ -1967,16 +2006,8 @@ def run_dense_gemm(
|
||||
gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream
|
||||
)
|
||||
|
||||
# Launch GPU kernel
|
||||
# Warm up
|
||||
for i in range(warmup_iterations):
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
# Execution
|
||||
for i in range(iterations):
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
|
||||
# Compute reference result
|
||||
if not skip_ref_check:
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
if ab_dtype in {
|
||||
cutlass.Int8,
|
||||
cutlass.Uint8,
|
||||
@ -2028,6 +2059,40 @@ def run_dense_gemm(
|
||||
rtol=1e-05,
|
||||
)
|
||||
|
||||
def generate_tensors():
|
||||
a_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
b_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
c_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a_torch_cpu.numel() * a_torch_cpu.element_size()
|
||||
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
|
||||
+ c_torch_cpu.numel() * c_torch_cpu.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -2090,6 +2155,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -2102,7 +2173,7 @@ if __name__ == "__main__":
|
||||
if len(args.cluster_shape_mn) != 2:
|
||||
parser.error("--cluster_shape_mn must contain exactly 2 values")
|
||||
|
||||
run_dense_gemm(
|
||||
run(
|
||||
args.mnkl,
|
||||
args.ab_dtype,
|
||||
args.c_dtype,
|
||||
@ -2118,5 +2189,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -223,7 +223,7 @@ class DenseGemmKernel:
|
||||
|
||||
self.occupancy = 1
|
||||
self.threads_per_cta = 128
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1063,11 +1063,7 @@ class DenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
|
||||
@ -43,6 +43,7 @@ import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.cute.typing import Int32, Int64, Float32, Boolean
|
||||
|
||||
@ -90,7 +91,7 @@ Constraints for this example:
|
||||
* Number of heads in Q must be divisible by number of heads in K
|
||||
* mma_tiler_mn must be 128,128
|
||||
* Batch size must be the same for Q, K, and V tensors
|
||||
* For causal masking, use --has_casual_mask (note: specify without =True/False)
|
||||
* For causal masking, use --is_causal (note: specify without =True/False)
|
||||
* For persistent scheduling, use --is_persistent (note: specify without =True/False)
|
||||
"""
|
||||
|
||||
@ -2373,11 +2374,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
smem_copy_atom = sm100_utils.get_smem_store_op(
|
||||
self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy(
|
||||
smem_copy_atom,
|
||||
layout_tv=tiled_tmem_load.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_tmem_load.tiler_mn,
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
|
||||
|
||||
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
|
||||
tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
|
||||
@ -2619,7 +2616,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
return tile_sched_params, grid
|
||||
|
||||
|
||||
def run_fmha_and_verify(
|
||||
def run(
|
||||
q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
|
||||
k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
|
||||
in_dtype: Type[cutlass.Numeric],
|
||||
@ -2628,7 +2625,7 @@ def run_fmha_and_verify(
|
||||
pv_acc_dtype: Type[cutlass.Numeric],
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
is_persistent: bool,
|
||||
has_casual_mask: bool,
|
||||
is_causal: bool,
|
||||
scale_q: float,
|
||||
scale_k: float,
|
||||
scale_v: float,
|
||||
@ -2638,6 +2635,8 @@ def run_fmha_and_verify(
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results.
|
||||
|
||||
@ -2670,8 +2669,8 @@ def run_fmha_and_verify(
|
||||
:type mma_tiler_mn: Tuple[int, int]
|
||||
:param is_persistent: Whether to use persistent kernel optimization
|
||||
:type is_persistent: bool
|
||||
:param has_casual_mask: Whether to apply causal masking
|
||||
:type has_casual_mask: bool
|
||||
:param is_causal: Whether to apply causal masking
|
||||
:type is_causal: bool
|
||||
:param scale_q: Scaling factor for query tensor
|
||||
:type scale_q: float
|
||||
:param scale_k: Scaling factor for key tensor
|
||||
@ -2690,9 +2689,13 @@ def run_fmha_and_verify(
|
||||
:type iterations: int
|
||||
:param skip_ref_check: Skip validation against reference implementation
|
||||
:type skip_ref_check: bool
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache
|
||||
:type use_cold_l2: bool
|
||||
|
||||
:raises ValueError: If input shapes are incompatible or head dimension is unsupported
|
||||
:raises RuntimeError: If GPU is unavailable for computation
|
||||
:return: Execution time of the FMHA kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
print(f"Running Blackwell SM100 FMHA test with:")
|
||||
@ -2704,13 +2707,17 @@ def run_fmha_and_verify(
|
||||
print(f" pv_acc_dtype: {pv_acc_dtype}")
|
||||
print(f" mma_tiler_mn: {mma_tiler_mn}")
|
||||
print(f" is_persistent: {is_persistent}")
|
||||
print(f" has_casual_mask: {has_casual_mask}")
|
||||
print(f" is_causal: {is_causal}")
|
||||
print(f" scale_q: {scale_q}")
|
||||
print(f" scale_k: {scale_k}")
|
||||
print(f" scale_v: {scale_v}")
|
||||
print(f" inv_scale_o: {inv_scale_o}")
|
||||
print(f" scale_softmax: {scale_softmax}")
|
||||
print(f" tolerance: {tolerance}")
|
||||
print(f" warmup_iterations: {warmup_iterations}")
|
||||
print(f" iterations: {iterations}")
|
||||
print(f" skip_ref_check: {skip_ref_check}")
|
||||
print(f" use_cold_l2: {use_cold_l2}")
|
||||
|
||||
# Unpack parameters
|
||||
b, s_q, h_q, d = q_shape
|
||||
@ -2882,7 +2889,7 @@ def run_fmha_and_verify(
|
||||
mma_tiler = (*mma_tiler_mn, d)
|
||||
|
||||
mask_type = MaskType.NO_MASK
|
||||
if has_casual_mask:
|
||||
if is_causal:
|
||||
mask_type = MaskType.CAUSAL_MASK
|
||||
else:
|
||||
if isinstance(s_k, tuple):
|
||||
@ -2942,41 +2949,7 @@ def run_fmha_and_verify(
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Execute kernel
|
||||
for _ in range(iterations):
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_torch_fmha(
|
||||
q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False
|
||||
):
|
||||
def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False):
|
||||
h_q = q.shape[2]
|
||||
h_k = k.shape[2]
|
||||
|
||||
@ -3005,7 +2978,7 @@ def run_fmha_and_verify(
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# For the situation that torch has not supported, we need to handle it manually
|
||||
situation1 = has_casual_mask and (q.is_nested or k.is_nested)
|
||||
situation1 = is_causal and (q.is_nested or k.is_nested)
|
||||
situation2 = (q.is_nested and not k.is_nested) or (
|
||||
not q.is_nested and k.is_nested
|
||||
)
|
||||
@ -3025,8 +2998,9 @@ def run_fmha_and_verify(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=has_casual_mask,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
ref_i = ref_i.transpose(0, 1) * scale_output
|
||||
ref_list.append(ref_i)
|
||||
if q.is_nested:
|
||||
ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged)
|
||||
@ -3040,15 +3014,28 @@ def run_fmha_and_verify(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=has_casual_mask,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
return ref
|
||||
|
||||
if not skip_ref_check:
|
||||
# Execute kernel once for reference checking
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
print("Verifying results...")
|
||||
o_ref = run_torch_fmha(
|
||||
q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask
|
||||
q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal
|
||||
)
|
||||
|
||||
if o_ref.is_nested:
|
||||
@ -3095,6 +3082,76 @@ def run_fmha_and_verify(
|
||||
torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
_, q_tensor_workspace, _ = create_and_pad_tensor(
|
||||
qo_shape,
|
||||
qo_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_q_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, k_tensor_workspace, _ = create_and_pad_tensor(
|
||||
kv_shape,
|
||||
kv_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_k_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, v_tensor_workspace, _ = create_and_pad_tensor(
|
||||
kv_shape,
|
||||
kv_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_k_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, o_tensor_workspace, _ = create_and_pad_tensor(
|
||||
qo_shape,
|
||||
qo_padding,
|
||||
out_dtype,
|
||||
s_cumsum=cum_seqlen_q_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
return testing.JitArguments(
|
||||
q_tensor_workspace.iterator,
|
||||
k_tensor_workspace.iterator,
|
||||
v_tensor_workspace.iterator,
|
||||
o_tensor_workspace.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch
|
||||
k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch
|
||||
v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch
|
||||
o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch
|
||||
one_workspace_bytes = (
|
||||
q_torch_effective.numel() * q_torch_effective.element_size()
|
||||
+ k_torch_effective.numel() * k_torch_effective.element_size()
|
||||
+ v_torch_effective.numel() * v_torch_effective.element_size()
|
||||
+ o_torch_effective.numel() * o_torch_effective.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_fmha,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str):
|
||||
@ -3185,7 +3242,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--has_casual_mask",
|
||||
"--is_causal",
|
||||
action="store_true",
|
||||
help="Whether to use casual mask",
|
||||
)
|
||||
@ -3263,6 +3320,13 @@ if __name__ == "__main__":
|
||||
help="Skip reference check",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if len(args.q_shape) != 4:
|
||||
@ -3279,7 +3343,7 @@ if __name__ == "__main__":
|
||||
|
||||
torch.manual_seed(1111)
|
||||
|
||||
run_fmha_and_verify(
|
||||
run(
|
||||
args.q_shape,
|
||||
args.k_shape,
|
||||
args.in_dtype,
|
||||
@ -3288,7 +3352,7 @@ if __name__ == "__main__":
|
||||
args.pv_acc_dtype,
|
||||
args.mma_tiler_mn,
|
||||
args.is_persistent,
|
||||
args.has_casual_mask,
|
||||
args.is_causal,
|
||||
args.scale_q,
|
||||
args.scale_k,
|
||||
args.scale_v,
|
||||
@ -3298,6 +3362,7 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
|
||||
print("PASS")
|
||||
|
||||
@ -36,6 +36,7 @@ import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
@ -157,7 +158,7 @@ class GroupedGemmKernel:
|
||||
self.tmem_ptr_sync_bar_id = 2
|
||||
# Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion
|
||||
self.tensormap_ab_init_bar_id = 4
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
self.num_tma_load_bytes = 0
|
||||
|
||||
def _setup_attributes(self):
|
||||
@ -951,7 +952,7 @@ class GroupedGemmKernel:
|
||||
# Specialized MMA warp
|
||||
#
|
||||
if warp_idx == self.mma_warp_id:
|
||||
# initilize tensormap A, B for TMA warp
|
||||
# initialize tensormap A, B for TMA warp
|
||||
if cutlass.const_expr(self.delegate_tensormap_ab_init):
|
||||
tensormap_manager.init_tensormap_from_atom(
|
||||
tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id
|
||||
@ -1540,11 +1541,7 @@ class GroupedGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1815,7 +1812,136 @@ class GroupedGemmKernel:
|
||||
tensor_memory_management_bytes = 12
|
||||
|
||||
|
||||
def run_grouped_gemm(
|
||||
# Create tensor and return the pointer, tensor, and stride
|
||||
def create_tensor_and_stride(
|
||||
l: int,
|
||||
mode0: int,
|
||||
mode1: int,
|
||||
is_mode0_major: bool,
|
||||
dtype: type[cutlass.Numeric],
|
||||
is_dynamic_layout: bool = True,
|
||||
torch_tensor_cpu: torch.Tensor = None,
|
||||
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
|
||||
"""Create a GPU tensor from scratch or based on an existing CPU tensor.
|
||||
|
||||
:param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one.
|
||||
:type torch_tensor_cpu: torch.Tensor, optional
|
||||
"""
|
||||
if torch_tensor_cpu is None:
|
||||
# Create new CPU tensor
|
||||
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
|
||||
|
||||
# Create GPU tensor from CPU tensor (new or existing)
|
||||
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
|
||||
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
|
||||
)
|
||||
return (
|
||||
torch_tensor.data_ptr(),
|
||||
torch_tensor,
|
||||
cute_tensor,
|
||||
torch_tensor_cpu,
|
||||
torch_tensor.stride()[:-1],
|
||||
)
|
||||
|
||||
|
||||
def create_tensors_for_all_groups(
|
||||
problem_sizes_mnkl: List[tuple[int, int, int, int]],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
torch_fp32_tensors_abc: List[List[torch.Tensor]] = None,
|
||||
) -> tuple[
|
||||
List[List[int]],
|
||||
List[List[torch.Tensor]],
|
||||
List[tuple],
|
||||
List[List[tuple]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len(
|
||||
problem_sizes_mnkl
|
||||
):
|
||||
raise ValueError("torch_fp32_tensors_abc must have one entry per group")
|
||||
|
||||
# Initialize lists to store tensors for all groups
|
||||
new_torch_fp32_tensors_abc = (
|
||||
[] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc
|
||||
)
|
||||
torch_tensors_abc = []
|
||||
cute_tensors_abc = []
|
||||
strides_abc = []
|
||||
ptrs_abc = []
|
||||
|
||||
# Iterate through all groups and create tensors for each group
|
||||
for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl):
|
||||
# Get existing CPU tensors if available, otherwise None
|
||||
existing_cpu_a = (
|
||||
torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None
|
||||
)
|
||||
existing_cpu_b = (
|
||||
torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None
|
||||
)
|
||||
existing_cpu_c = (
|
||||
torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None
|
||||
)
|
||||
|
||||
# Create tensors (reusing CPU tensors if provided)
|
||||
(
|
||||
ptr_a,
|
||||
torch_tensor_a,
|
||||
cute_tensor_a,
|
||||
tensor_fp32_a,
|
||||
stride_mk_a,
|
||||
) = create_tensor_and_stride(
|
||||
l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a
|
||||
)
|
||||
(
|
||||
ptr_b,
|
||||
torch_tensor_b,
|
||||
cute_tensor_b,
|
||||
tensor_fp32_b,
|
||||
stride_nk_b,
|
||||
) = create_tensor_and_stride(
|
||||
l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b
|
||||
)
|
||||
(
|
||||
ptr_c,
|
||||
torch_tensor_c,
|
||||
cute_tensor_c,
|
||||
tensor_fp32_c,
|
||||
stride_mn_c,
|
||||
) = create_tensor_and_stride(
|
||||
l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c
|
||||
)
|
||||
|
||||
# Only append to new_torch_fp32_tensors_abc if we created new CPU tensors
|
||||
if torch_fp32_tensors_abc is None:
|
||||
new_torch_fp32_tensors_abc.append(
|
||||
[tensor_fp32_a, tensor_fp32_b, tensor_fp32_c]
|
||||
)
|
||||
|
||||
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
|
||||
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
|
||||
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
|
||||
cute_tensors_abc.append(
|
||||
(
|
||||
cute_tensor_a,
|
||||
cute_tensor_b,
|
||||
cute_tensor_c,
|
||||
)
|
||||
)
|
||||
|
||||
return (
|
||||
ptrs_abc,
|
||||
torch_tensors_abc,
|
||||
cute_tensors_abc,
|
||||
strides_abc,
|
||||
new_torch_fp32_tensors_abc,
|
||||
)
|
||||
|
||||
|
||||
def run(
|
||||
num_groups: int,
|
||||
problem_sizes_mnkl: tuple[int, int, int, int],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
@ -1832,8 +1958,16 @@ def run_grouped_gemm(
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Run grouped GEMM example with specified configurations."""
|
||||
"""Run grouped GEMM example with specified configurations.
|
||||
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:return: Execution time of the GEMM kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
print(f"Running Blackwell Grouped GEMM test with:")
|
||||
print(f"{num_groups} groups")
|
||||
for i, (m, n, k, l) in enumerate(problem_sizes_mnkl):
|
||||
@ -1847,6 +1981,7 @@ def run_grouped_gemm(
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
|
||||
# Skip unsupported types
|
||||
if ab_dtype not in {
|
||||
@ -1902,66 +2037,22 @@ def run_grouped_gemm(
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("GPU is required to run this example!")
|
||||
|
||||
# Create tensor and return the pointer, tensor, and stride
|
||||
def create_tensor_and_stride(
|
||||
l: int,
|
||||
mode0: int,
|
||||
mode1: int,
|
||||
is_mode0_major: bool,
|
||||
dtype: type[cutlass.Numeric],
|
||||
is_dynamic_layout: bool = True,
|
||||
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
|
||||
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
|
||||
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
|
||||
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
|
||||
)
|
||||
return (
|
||||
torch_tensor.data_ptr(),
|
||||
torch_tensor,
|
||||
cute_tensor,
|
||||
torch_tensor_cpu,
|
||||
torch_tensor.stride()[:-1],
|
||||
)
|
||||
# Create tensors for all groups using the new function
|
||||
(
|
||||
ptrs_abc,
|
||||
torch_tensors_abc,
|
||||
cute_tensors_abc,
|
||||
strides_abc,
|
||||
torch_fp32_tensors_abc,
|
||||
) = create_tensors_for_all_groups(
|
||||
problem_sizes_mnkl,
|
||||
ab_dtype,
|
||||
c_dtype,
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
)
|
||||
|
||||
# iterate all groups and create tensors for each group
|
||||
torch_fp32_tensors_abc = []
|
||||
torch_tensors_abc = []
|
||||
cute_tensors_abc = []
|
||||
strides_abc = []
|
||||
ptrs_abc = []
|
||||
for _, (m, n, k, l) in enumerate(problem_sizes_mnkl):
|
||||
(
|
||||
ptr_a,
|
||||
torch_tensor_a,
|
||||
cute_tensor_a,
|
||||
tensor_fp32_a,
|
||||
stride_mk_a,
|
||||
) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
|
||||
(
|
||||
ptr_b,
|
||||
torch_tensor_b,
|
||||
cute_tensor_b,
|
||||
tensor_fp32_b,
|
||||
stride_nk_b,
|
||||
) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
|
||||
(
|
||||
ptr_c,
|
||||
torch_tensor_c,
|
||||
cute_tensor_c,
|
||||
tensor_fp32_c,
|
||||
stride_mn_c,
|
||||
) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
|
||||
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
|
||||
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
|
||||
torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c])
|
||||
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
|
||||
cute_tensors_abc.append(
|
||||
(
|
||||
cute_tensor_a,
|
||||
cute_tensor_b,
|
||||
cute_tensor_c,
|
||||
)
|
||||
)
|
||||
# Choose A, B, C with the smallest size to create initial tensormaps
|
||||
key_size_a = lambda item: item[1][0] * item[1][2]
|
||||
key_size_b = lambda item: item[1][1] * item[1][2]
|
||||
@ -2078,36 +2169,19 @@ def run_grouped_gemm(
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Launch GPU kernel
|
||||
# Warm up
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_grouped_gemm(
|
||||
initial_cute_tensors_abc[0],
|
||||
initial_cute_tensors_abc[1],
|
||||
initial_cute_tensors_abc[2],
|
||||
tensor_of_dim_size_mnkl,
|
||||
tensor_of_strides_abc,
|
||||
tensor_of_ptrs_abc,
|
||||
tensor_of_tensormap,
|
||||
current_stream,
|
||||
)
|
||||
# Execution
|
||||
for i in range(iterations):
|
||||
compiled_grouped_gemm(
|
||||
initial_cute_tensors_abc[0],
|
||||
initial_cute_tensors_abc[1],
|
||||
initial_cute_tensors_abc[2],
|
||||
tensor_of_dim_size_mnkl,
|
||||
tensor_of_strides_abc,
|
||||
tensor_of_ptrs_abc,
|
||||
tensor_of_tensormap,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Compute reference result
|
||||
if not skip_ref_check:
|
||||
compiled_grouped_gemm(
|
||||
initial_cute_tensors_abc[0],
|
||||
initial_cute_tensors_abc[1],
|
||||
initial_cute_tensors_abc[2],
|
||||
tensor_of_dim_size_mnkl,
|
||||
tensor_of_strides_abc,
|
||||
tensor_of_ptrs_abc,
|
||||
tensor_of_tensormap,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Compute reference result
|
||||
for i, (a, b, c) in enumerate(torch_tensors_abc):
|
||||
ref = torch.einsum(
|
||||
"mkl,nkl->mnl",
|
||||
@ -2122,6 +2196,102 @@ def run_grouped_gemm(
|
||||
rtol=1e-05,
|
||||
)
|
||||
|
||||
def generate_tensors():
|
||||
# Reuse existing CPU tensors and create new GPU tensors from them
|
||||
(
|
||||
ptrs_abc_workspace,
|
||||
torch_tensors_abc_workspace,
|
||||
cute_tensors_abc_workspace,
|
||||
strides_abc_workspace,
|
||||
_,
|
||||
) = create_tensors_for_all_groups(
|
||||
problem_sizes_mnkl,
|
||||
ab_dtype,
|
||||
c_dtype,
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
torch_fp32_tensors_abc,
|
||||
)
|
||||
|
||||
initial_cute_tensors_abc_workspace = [
|
||||
cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k)
|
||||
cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k)
|
||||
cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n)
|
||||
]
|
||||
|
||||
# Create new tensors for this workspace
|
||||
tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like(
|
||||
torch.tensor(strides_abc_workspace, dtype=torch.int32),
|
||||
cutlass.Int32,
|
||||
is_dynamic_layout=False,
|
||||
assumed_align=16,
|
||||
)
|
||||
|
||||
tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like(
|
||||
torch.tensor(ptrs_abc_workspace, dtype=torch.int64),
|
||||
cutlass.Int64,
|
||||
is_dynamic_layout=False,
|
||||
assumed_align=16,
|
||||
)
|
||||
|
||||
tensormap_workspace, _ = cutlass_torch.cute_tensor_like(
|
||||
torch.empty(tensormap_shape, dtype=torch.int64),
|
||||
cutlass.Int64,
|
||||
is_dynamic_layout=False,
|
||||
)
|
||||
|
||||
return testing.JitArguments(
|
||||
initial_cute_tensors_abc_workspace[0],
|
||||
initial_cute_tensors_abc_workspace[1],
|
||||
initial_cute_tensors_abc_workspace[2],
|
||||
tensor_of_dim_size_mnkl,
|
||||
tensor_of_strides_abc_workspace,
|
||||
tensor_of_ptrs_abc_workspace,
|
||||
tensormap_workspace,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
sum(
|
||||
[
|
||||
sum(
|
||||
[
|
||||
torch_tensor.numel() * torch_tensor.element_size()
|
||||
for torch_tensor in group_tensors
|
||||
]
|
||||
)
|
||||
for group_tensors in torch_tensors_abc
|
||||
]
|
||||
)
|
||||
+
|
||||
# Add size of strides tensor
|
||||
tensor_of_strides_abc_torch.numel()
|
||||
* tensor_of_strides_abc_torch.element_size()
|
||||
+
|
||||
# Add size of ptrs tensor
|
||||
tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size()
|
||||
+
|
||||
# Add size of tensormap tensor
|
||||
tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_grouped_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -2218,6 +2388,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -2248,7 +2424,7 @@ if __name__ == "__main__":
|
||||
|
||||
torch.manual_seed(2025)
|
||||
|
||||
run_grouped_gemm(
|
||||
run(
|
||||
args.num_groups,
|
||||
args.problem_sizes_mnkl,
|
||||
args.ab_dtype,
|
||||
@ -2265,5 +2441,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -29,13 +29,14 @@
|
||||
|
||||
import argparse
|
||||
from typing import List, Type, Tuple, Optional
|
||||
from cuda import cuda
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
@ -43,13 +44,16 @@ import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
from .mamba2_ssd_reference import (
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent))
|
||||
from mamba2_ssd_reference import (
|
||||
ssd_reference_fp32_all,
|
||||
ssd_reference_lowprecision_intermediates,
|
||||
analyze_relative_diffs,
|
||||
)
|
||||
|
||||
from .mamba2_ssd_tile_scheduler import (
|
||||
from mamba2_ssd_tile_scheduler import (
|
||||
Mamba2SSDTileSchedulerParams,
|
||||
Mamba2SSDTileScheduler,
|
||||
)
|
||||
@ -122,7 +126,7 @@ class SSDKernel:
|
||||
*self.epilog_warp_id,
|
||||
)
|
||||
)
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
# Named barriers
|
||||
self.pre_inter_sync_bar_id = 1
|
||||
@ -1522,7 +1526,10 @@ class SSDKernel:
|
||||
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
|
||||
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
|
||||
tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b(
|
||||
local_tidx, smem_bt_internal_, tiled_s2r_b, tBrB_s2r
|
||||
local_tidx,
|
||||
smem_bt_internal_,
|
||||
tiled_s2r_b,
|
||||
tBrB_s2r,
|
||||
)
|
||||
|
||||
# (MMA, MMA_M, MMA_K, INPUT_STAGE)
|
||||
@ -3053,7 +3060,7 @@ class SSDKernel:
|
||||
|
||||
# SegSum
|
||||
# fadd2 + fsel + fmul2/mufu + fmul2
|
||||
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
|
||||
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
|
||||
(
|
||||
tCompute[subtile_idx],
|
||||
tCompute[subtile_idx + 1],
|
||||
@ -3061,11 +3068,11 @@ class SSDKernel:
|
||||
(tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]),
|
||||
(-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]),
|
||||
)
|
||||
for subtile_idx in range(cute.size(tTR_rQ)):
|
||||
for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True):
|
||||
m, n = tCoord[subtile_idx]
|
||||
if m < n:
|
||||
tCompute[subtile_idx] = cutlass.Float32(-float("inf"))
|
||||
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
|
||||
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
|
||||
# TODO: use math.exp directly
|
||||
(
|
||||
tCompute[subtile_idx],
|
||||
@ -3130,11 +3137,7 @@ class SSDKernel:
|
||||
dtype,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_r2s_b = cute.make_tiled_copy(
|
||||
copy_atom_r2s_b,
|
||||
layout_tv=tiled_s2r_b.layout_tv_tiled,
|
||||
tiler_mn=tiled_s2r_b.tiler_mn,
|
||||
)
|
||||
tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b)
|
||||
thr_r2s_b = tiled_r2s_b.get_slice(local_tidx)
|
||||
|
||||
# Partition shared tensor for smem store Bt
|
||||
@ -3333,17 +3336,24 @@ class SSDKernel:
|
||||
)
|
||||
|
||||
|
||||
def run_ssd(
|
||||
def run(
|
||||
gbehcdln: Tuple[int, int, int, int, int, int, int, int],
|
||||
io_dtype: Type[cutlass.Numeric],
|
||||
cumsum_delta_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
has_d: bool,
|
||||
d_has_hdim: bool,
|
||||
fuse_scale_d: str,
|
||||
tolerance: float,
|
||||
print_rtol_stats: bool,
|
||||
ref_lower_precision: bool,
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
has_d = fuse_scale_d != "none"
|
||||
d_has_hdim = fuse_scale_d == "vector"
|
||||
|
||||
print(f"Running B100 Mamba2 SSD with:")
|
||||
print(f"GBEHCDLN: {gbehcdln}")
|
||||
print(
|
||||
@ -3353,6 +3363,10 @@ def run_ssd(
|
||||
f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}"
|
||||
)
|
||||
print(f"Tolerance: {tolerance}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
|
||||
# Unpack parameters
|
||||
G, B, E, H, C, D, L, N = gbehcdln
|
||||
@ -3515,39 +3529,146 @@ def run_ssd(
|
||||
stream,
|
||||
)
|
||||
|
||||
# Launch compiled ssd kernel
|
||||
compiled_ssd(
|
||||
x_tensor,
|
||||
cumsum_delta_tensor,
|
||||
delta_tensor,
|
||||
b_tensor,
|
||||
c_tensor,
|
||||
y_tensor,
|
||||
fstate_tensor,
|
||||
d_tensor,
|
||||
stream,
|
||||
# Launch compiled ssd kernel for reference check
|
||||
if not skip_ref_check:
|
||||
compiled_ssd(
|
||||
x_tensor,
|
||||
cumsum_delta_tensor,
|
||||
delta_tensor,
|
||||
b_tensor,
|
||||
c_tensor,
|
||||
y_tensor,
|
||||
fstate_tensor,
|
||||
d_tensor,
|
||||
stream,
|
||||
)
|
||||
|
||||
# Reference check
|
||||
if print_rtol_stats:
|
||||
print("\nY's Relative diffs:")
|
||||
analyze_relative_diffs(
|
||||
y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype))
|
||||
)
|
||||
print("\nFstate's Relative diffs:")
|
||||
analyze_relative_diffs(
|
||||
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
y_torch.cpu(),
|
||||
y_ref.to(cutlass_torch.dtype(io_dtype)),
|
||||
atol=tolerance,
|
||||
rtol=1e-02,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
fstate_torch.cpu(),
|
||||
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
|
||||
atol=tolerance,
|
||||
rtol=1e-05,
|
||||
)
|
||||
|
||||
def generate_tensors():
|
||||
# Reuse existing CPU reference tensors and create new GPU tensors from them
|
||||
_, x_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, D, C, L],
|
||||
[2, 4, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=x_ref,
|
||||
dynamic_modes=[2, 3, 4],
|
||||
)
|
||||
_, cumsum_delta_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, C, L],
|
||||
[3, 2, 1, 0],
|
||||
cumsum_delta_dtype,
|
||||
ref_tensor=cumsum_delta_ref,
|
||||
dynamic_modes=[1, 2, 3],
|
||||
)
|
||||
_, delta_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, C, L],
|
||||
[3, 2, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=delta_ref,
|
||||
dynamic_modes=[1, 2, 3],
|
||||
)
|
||||
_, b_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, G, N, C, L],
|
||||
[4, 2, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=b_ref,
|
||||
dynamic_modes=[2, 3, 4],
|
||||
)
|
||||
_, c_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, G, N, C, L],
|
||||
[4, 2, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=c_ref,
|
||||
dynamic_modes=[2, 3, 4],
|
||||
)
|
||||
_, y_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, D, C, L],
|
||||
[4, 2, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=y_ref,
|
||||
dynamic_modes=[2, 3, 4],
|
||||
)
|
||||
_, fstate_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, D, N],
|
||||
[2, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=fstate_ref,
|
||||
dynamic_modes=[2, 3],
|
||||
)
|
||||
|
||||
if has_d:
|
||||
_, d_tensor_new, _ = create_and_permute_tensor(
|
||||
[EH, D if d_has_hdim else 1],
|
||||
[1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=d_ref,
|
||||
dynamic_modes=[1],
|
||||
)
|
||||
else:
|
||||
d_tensor_new = d_tensor
|
||||
|
||||
return testing.JitArguments(
|
||||
x_tensor_new,
|
||||
cumsum_delta_tensor_new,
|
||||
delta_tensor_new,
|
||||
b_tensor_new,
|
||||
c_tensor_new,
|
||||
y_tensor_new,
|
||||
fstate_tensor_new,
|
||||
d_tensor_new,
|
||||
stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
x_torch.numel() * x_torch.element_size()
|
||||
+ cumsum_delta_torch.numel() * cumsum_delta_torch.element_size()
|
||||
+ delta_torch.numel() * delta_torch.element_size()
|
||||
+ b_torch.numel() * b_torch.element_size()
|
||||
+ c_torch.numel() * c_torch.element_size()
|
||||
+ y_torch.numel() * y_torch.element_size()
|
||||
+ fstate_torch.numel() * fstate_torch.element_size()
|
||||
)
|
||||
if has_d:
|
||||
one_workspace_bytes += d_torch.numel() * d_torch.element_size()
|
||||
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_ssd,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
# Reference check
|
||||
if print_rtol_stats:
|
||||
print("\nY's Relative diffs:")
|
||||
analyze_relative_diffs(y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype)))
|
||||
print("\nFstate's Relative diffs:")
|
||||
analyze_relative_diffs(
|
||||
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
y_torch.cpu(),
|
||||
y_ref.to(cutlass_torch.dtype(io_dtype)),
|
||||
atol=tolerance,
|
||||
rtol=1e-02,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
fstate_torch.cpu(),
|
||||
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
|
||||
atol=tolerance,
|
||||
rtol=1e-05,
|
||||
)
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -3586,15 +3707,53 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_lower_precision",
|
||||
type=bool,
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use lower precision for reference check",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-ref_lower_precision",
|
||||
action="store_false",
|
||||
dest="ref_lower_precision",
|
||||
default=False,
|
||||
help="Disable lower precision for reference check",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tolerance", type=float, default=5e-02, help="Tolerance for validation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_rtol_stats", type=bool, default=True, help="Print rtol stats"
|
||||
"--print_rtol_stats",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Enable print rtol stats",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-print_rtol_stats",
|
||||
action="store_false",
|
||||
dest="print_rtol_stats",
|
||||
default=False,
|
||||
help="Disable print rtol stats",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_iterations",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of warmup iterations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iterations",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of iterations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -3602,18 +3761,18 @@ if __name__ == "__main__":
|
||||
if len(args.gbehcdln) != 8:
|
||||
parser.error("--gbehcdln must contain exactly 8 values")
|
||||
|
||||
has_d = args.fuse_scale_d != "none"
|
||||
d_has_hdim = args.fuse_scale_d == "vector"
|
||||
|
||||
run_ssd(
|
||||
run(
|
||||
args.gbehcdln,
|
||||
args.io_dtype,
|
||||
args.cumsum_delta_dtype,
|
||||
args.acc_dtype,
|
||||
has_d,
|
||||
d_has_hdim,
|
||||
args.fuse_scale_d,
|
||||
args.tolerance,
|
||||
args.print_rtol_stats,
|
||||
args.ref_lower_precision,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
Reference in New Issue
Block a user