v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -35,6 +35,7 @@ import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
@ -166,6 +167,24 @@ def parse_arguments() -> argparse.Namespace:
parser.add_argument(
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
)
parser.add_argument(
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations to run the kernel",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -264,7 +283,7 @@ class HopperWgmmaGemmKernel:
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
self.num_threads_per_warp_group = 128
self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group
self.smem_capacity = sm90_utils.SMEM_CAPACITY["sm90"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
self.ab_stage = None
self.epi_stage = None
@ -1309,7 +1328,7 @@ class HopperWgmmaGemmKernel:
}:
is_valid = False
# tested acc_dtype
if acc_dtype != cutlass.Float32:
if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
is_valid = False
# tested c_dtype
if c_dtype not in {
@ -1335,7 +1354,7 @@ class HopperWgmmaGemmKernel:
return is_valid
def run_dense_gemm(
def run(
mnkl: Tuple[int, int, int, int],
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
@ -1347,9 +1366,43 @@ def run_dense_gemm(
tile_shape_mnk: Tuple[int, int, int],
cluster_shape_mn: Tuple[int, int],
tolerance: float,
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
:param mnkl: Problem size (M, N, K, L)
:type mnkl: Tuple[int, int, int, int]
:param a_dtype: Data type for input tensor A
:type a_dtype: Type[cutlass.Numeric]
:param b_dtype: Data type for input tensor B
:type b_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C
:type c_dtype: Type[cutlass.Numeric]
:param acc_dtype: Data type for accumulation during matrix multiplication
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param tile_shape_mnk: CTA tile shape (M, N, K)
:type tile_shape_mnk: Tuple[int, int, int]
:param cluster_shape_mn: Cluster shape (M, N)
:type cluster_shape_mn: Tuple[int, int]
:param tolerance: Tolerance value for reference validation comparison
:type tolerance: float
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 1
:type iterations: int, optional
:param skip_ref_check: Whether to skip reference result validation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""
print(f"Running Hopper Dense GEMM with:")
@ -1360,6 +1413,10 @@ def run_dense_gemm(
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")
# Unpack parameters
m, n, k, l = mnkl
@ -1437,46 +1494,76 @@ def run_dense_gemm(
stream = cuda.CUstream(torch_stream.cuda_stream)
# compile gemm kernel
compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
# execution
compiled_gemm(mA, mB, mC, stream)
torch.cuda.synchronize()
if not skip_ref_check:
# execution
compiled_gemm(mA, mB, mC, stream)
# Ref check
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
torch.cuda.synchronize()
if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
# m major: (l, n, m) -> (m, n, l)
# k major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch.uint8,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.SKIP,
).cuda()
# Create dtype cute tensor (gpu)
ref_c_tensor = from_dlpack(
f8_torch_tensor, assumed_align=16
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
ref_c_tensor.element_type = c_dtype
ref_c_tensor = cutlass_torch.convert_cute_tensor(
ref,
ref_c_tensor,
c_dtype,
is_dynamic_layout=True,
# Ref check
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
# m major: (l, n, m) -> (m, n, l)
# n major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch.uint8,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.SKIP,
).cuda()
# Create dtype cute tensor (gpu)
ref_c_tensor = from_dlpack(
f8_torch_tensor, assumed_align=16
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
ref_c_tensor.element_type = c_dtype
ref_c_tensor = cutlass_torch.convert_cute_tensor(
ref,
ref_c_tensor,
c_dtype,
is_dynamic_layout=True,
)
ref_c = f8_torch_tensor.cpu()
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
def generate_tensors():
_, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
_, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
_, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch.numel() * a_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
ref_c = f8_torch_tensor.cpu()
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
exec_time = testing.benchmark(
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
args = parse_arguments()
run_dense_gemm(
run(
args.mnkl,
args.a_dtype,
args.b_dtype,
@ -1488,5 +1575,9 @@ if __name__ == "__main__":
args.tile_shape_mnk,
args.cluster_shape_mn,
args.tolerance,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")