v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -38,6 +38,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.cute.testing as testing
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack
@ -226,7 +227,7 @@ class PersistentDenseGemmKernel:
self.cta_sync_bar_id = 0
self.epilog_sync_bar_id = 1
self.tmem_ptr_sync_bar_id = 2
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -1308,11 +1309,7 @@ class PersistentDenseGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
@ -1824,7 +1821,7 @@ class PersistentDenseGemmKernel:
return can_implement
def run_dense_gemm(
def run(
mnkl: Tuple[int, int, int, int],
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
@ -1832,17 +1829,58 @@ def run_dense_gemm(
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
use_2cta_instrs: bool,
use_tma_store: bool,
tolerance: float,
mma_tiler_mn: Tuple[int, int] = (256, 256),
cluster_shape_mn: Tuple[int, int] = (2, 1),
use_2cta_instrs: bool = True,
use_tma_store: bool = True,
tolerance: float = 1e-01,
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
"""
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
"""Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
This function prepares input tensors, configures and launches the persistent GEMM kernel,
optionally performs reference validation, and benchmarks the execution performance.
:param mnkl: Problem size (M, N, K, L)
:type mnkl: Tuple[int, int, int, int]
:param ab_dtype: Data type for input tensors A and B
:type ab_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C
:type c_dtype: Type[cutlass.Numeric]
:param acc_dtype: Data type for accumulation during matrix multiplication
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the
default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters.
:type mma_tiler_mn: Tuple[int, int], optional
:param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
:type cluster_shape_mn: Tuple[int, int], optional
:param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
:type use_2cta_instrs: bool, optional
:param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use
the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
:type use_tma_store: bool, optional
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
:type tolerance: float, optional
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 1
:type iterations: int, optional
:param skip_ref_check: Whether to skip reference result validation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use a circular-buffer strategy to ensure a cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:raises RuntimeError: If CUDA GPU is not available
:raises ValueError: If the configuration is invalid or unsupported by the kernel
:return: Execution time of the GEMM kernel, in microseconds
:rtype: float
"""
print(f"Running Blackwell Persistent Dense GEMM test with:")
print(f"mnkl: {mnkl}")
@ -1855,6 +1893,7 @@ def run_dense_gemm(
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
# Unpack parameters
m, n, k, l = mnkl
@ -1931,15 +1970,15 @@ def run_dense_gemm(
is_dynamic_layout=is_dynamic_layout,
)
return f32_torch_tensor, cute_tensor, torch_tensor
return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu
a_ref, a_tensor, a_torch = create_and_permute_tensor(
a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor(
l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
)
b_ref, b_tensor, b_torch = create_and_permute_tensor(
b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor(
l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
)
c_ref, c_tensor, c_torch = create_and_permute_tensor(
c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor(
l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True
)
@ -1967,16 +2006,8 @@ def run_dense_gemm(
gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream
)
# Launch GPU kernel
# Warm up
for i in range(warmup_iterations):
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
# Execution
for i in range(iterations):
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
# Compute reference result
if not skip_ref_check:
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
if ab_dtype in {
cutlass.Int8,
cutlass.Uint8,
@ -2028,6 +2059,40 @@ def run_dense_gemm(
rtol=1e-05,
)
def generate_tensors():
a_tensor, _ = cutlass_torch.cute_tensor_like(
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
c_tensor, _ = cutlass_torch.cute_tensor_like(
c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
)
return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch_cpu.numel() * a_torch_cpu.element_size()
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
+ c_torch_cpu.numel() * c_torch_cpu.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
@ -2090,6 +2155,12 @@ if __name__ == "__main__":
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -2102,7 +2173,7 @@ if __name__ == "__main__":
if len(args.cluster_shape_mn) != 2:
parser.error("--cluster_shape_mn must contain exactly 2 values")
run_dense_gemm(
run(
args.mnkl,
args.ab_dtype,
args.c_dtype,
@ -2118,5 +2189,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")