v4.1 release update v2. (#2481)
@@ -35,6 +35,7 @@ import torch

 import cutlass
 import cutlass.cute as cute
+import cutlass.cute.testing as testing
 import cutlass.utils as utils
 import cutlass.pipeline as pipeline
 import cutlass.torch as cutlass_torch
@@ -166,6 +167,24 @@ def parse_arguments() -> argparse.Namespace:
     parser.add_argument(
         "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
     )
+    parser.add_argument(
+        "--warmup_iterations", type=int, default=0, help="Warmup iterations"
+    )
+    parser.add_argument(
+        "--iterations",
+        type=int,
+        default=1,
+        help="Number of iterations to run the kernel",
+    )
+    parser.add_argument(
+        "--skip_ref_check", action="store_true", help="Skip reference checking"
+    )
+    parser.add_argument(
+        "--use_cold_l2",
+        action="store_true",
+        default=False,
+        help="Use circular buffer tensor sets to ensure L2 cold cache",
+    )

     args = parser.parse_args()
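Note: taken together, the new flags drive the benchmark loop added further down in this commit. An illustrative invocation (script name and values assumed here, not part of the diff):

    python dense_gemm.py --warmup_iterations 2 --iterations 10 --use_cold_l2

Passing --skip_ref_check additionally bypasses the torch.einsum reference comparison when only timing is of interest.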
@@ -264,7 +283,7 @@ class HopperWgmmaGemmKernel:
         self.mma_warp_groups = math.prod(self.atom_layout_mnk)
         self.num_threads_per_warp_group = 128
         self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group
-        self.smem_capacity = sm90_utils.SMEM_CAPACITY["sm90"]
+        self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")

         self.ab_stage = None
         self.epi_stage = None
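Note: this swaps a hard-coded capacity table for a query helper. A minimal sketch of the intended equivalence, assuming both symbols are importable in this tree:

    # Both forms should report the Hopper per-SM shared memory budget in
    # bytes; the helper keeps kernel code independent of the lookup table.
    assert utils.get_smem_capacity_in_bytes("sm_90") == sm90_utils.SMEM_CAPACITY["sm90"]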
@@ -1309,7 +1328,7 @@ class HopperWgmmaGemmKernel:
         }:
             is_valid = False
         # tested acc_dtype
-        if acc_dtype != cutlass.Float32:
+        if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
             is_valid = False
         # tested c_dtype
         if c_dtype not in {
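Note: the accumulator check is widened from Float32-only to a tested set. The new predicate in isolation (names from the diff; shown affirmatively rather than as the early-out in the code above):

    # Float16 accumulation is now accepted alongside Float32.
    is_valid = acc_dtype in {cutlass.Float32, cutlass.Float16}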
@@ -1335,7 +1354,7 @@ class HopperWgmmaGemmKernel:
         return is_valid


-def run_dense_gemm(
+def run(
     mnkl: Tuple[int, int, int, int],
     a_dtype: Type[cutlass.Numeric],
     b_dtype: Type[cutlass.Numeric],
@@ -1347,9 +1366,43 @@ def run_dense_gemm(
     tile_shape_mnk: Tuple[int, int, int],
     cluster_shape_mn: Tuple[int, int],
     tolerance: float,
+    warmup_iterations: int,
+    iterations: int,
+    skip_ref_check: bool,
+    use_cold_l2: bool = False,
+    **kwargs,
 ):
     """
     Prepare A/B/C tensors, launch GPU kernel, and reference checking.
+
+    :param mnkl: Problem size (M, N, K, L)
+    :type mnkl: Tuple[int, int, int, int]
+    :param a_dtype: Data type for input tensor A
+    :type a_dtype: Type[cutlass.Numeric]
+    :param b_dtype: Data type for input tensor B
+    :type b_dtype: Type[cutlass.Numeric]
+    :param c_dtype: Data type for output tensor C
+    :type c_dtype: Type[cutlass.Numeric]
+    :param acc_dtype: Data type for accumulation during matrix multiplication
+    :type acc_dtype: Type[cutlass.Numeric]
+    :param a_major/b_major/c_major: Memory layout of tensor A/B/C
+    :type a_major/b_major/c_major: str
+    :param tile_shape_mnk: CTA tile shape (M, N, K)
+    :type tile_shape_mnk: Tuple[int, int, int]
+    :param cluster_shape_mn: Cluster shape (M, N)
+    :type cluster_shape_mn: Tuple[int, int]
+    :param tolerance: Tolerance value for reference validation comparison
+    :type tolerance: float
+    :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
+    :type warmup_iterations: int, optional
+    :param iterations: Number of benchmark iterations to run, defaults to 1
+    :type iterations: int, optional
+    :param skip_ref_check: Whether to skip reference result validation, defaults to False
+    :type skip_ref_check: bool, optional
+    :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
+    :type use_cold_l2: bool, optional
+    :return: Execution time of the GEMM kernel in microseconds
+    :rtype: float
     """

     print(f"Running Hopper Dense GEMM with:")
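Note: with the renamed entry point and the extended signature, a direct call might look as follows. Values are illustrative only, the positional order follows the docstring above, and the dtype/major/tile choices still have to pass the kernel's validity checks:

    exec_time_us = run(
        (8192, 8192, 8192, 1),   # mnkl
        cutlass.Float16,         # a_dtype
        cutlass.Float16,         # b_dtype
        cutlass.Float16,         # c_dtype
        cutlass.Float32,         # acc_dtype
        "k", "k", "n",           # a_major, b_major, c_major
        (128, 256, 64),          # tile_shape_mnk
        (1, 1),                  # cluster_shape_mn
        1e-01,                   # tolerance
        2,                       # warmup_iterations
        10,                      # iterations
        False,                   # skip_ref_check
        use_cold_l2=True,
    )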
@@ -1360,6 +1413,10 @@ def run_dense_gemm(
     print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
     print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
     print(f"Tolerance: {tolerance}")
+    print(f"Warmup iterations: {warmup_iterations}")
+    print(f"Iterations: {iterations}")
+    print(f"Skip reference checking: {skip_ref_check}")
+    print(f"Use cold L2: {use_cold_l2}")

     # Unpack parameters
     m, n, k, l = mnkl
@@ -1437,46 +1494,76 @@ def run_dense_gemm(
     stream = cuda.CUstream(torch_stream.cuda_stream)
     # compile gemm kernel
     compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
-    # execution
-    compiled_gemm(mA, mB, mC, stream)
-
-    torch.cuda.synchronize()
-
-    # Ref check
-    ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
-
-    if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
-        # m major: (l, n, m) -> (m, n, l)
-        # k major: (l, m, n) -> (m, n, l)
-        permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
-        shape = (l, m, n) if c_major == "n" else (l, n, m)
-        f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
-            shape,
-            torch.uint8,
-            permute_order=permute_order,
-            init_type=cutlass_torch.TensorInitType.SKIP,
-        ).cuda()
-        # Create dtype cute tensor (gpu)
-        ref_c_tensor = from_dlpack(
-            f8_torch_tensor, assumed_align=16
-        ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
-        ref_c_tensor.element_type = c_dtype
-        ref_c_tensor = cutlass_torch.convert_cute_tensor(
-            ref,
-            ref_c_tensor,
-            c_dtype,
-            is_dynamic_layout=True,
-        )
-        ref_c = f8_torch_tensor.cpu()
-    else:
-        ref_c = ref.to(cutlass_torch.dtype(c_dtype))
-
-    torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
+    if not skip_ref_check:
+        # execution
+        compiled_gemm(mA, mB, mC, stream)
+
+        torch.cuda.synchronize()
+
+        # Ref check
+        ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
+
+        if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
+            # m major: (l, n, m) -> (m, n, l)
+            # n major: (l, m, n) -> (m, n, l)
+            permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
+            shape = (l, m, n) if c_major == "n" else (l, n, m)
+            f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
+                shape,
+                torch.uint8,
+                permute_order=permute_order,
+                init_type=cutlass_torch.TensorInitType.SKIP,
+            ).cuda()
+            # Create dtype cute tensor (gpu)
+            ref_c_tensor = from_dlpack(
+                f8_torch_tensor, assumed_align=16
+            ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
+            ref_c_tensor.element_type = c_dtype
+            ref_c_tensor = cutlass_torch.convert_cute_tensor(
+                ref,
+                ref_c_tensor,
+                c_dtype,
+                is_dynamic_layout=True,
+            )
+            ref_c = f8_torch_tensor.cpu()
+        else:
+            ref_c = ref.to(cutlass_torch.dtype(c_dtype))
+
+        torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
+
+    def generate_tensors():
+        _, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
+        _, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
+        _, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
+        return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
+
+    workspace_count = 1
+    if use_cold_l2:
+        one_workspace_bytes = (
+            a_torch.numel() * a_torch.element_size()
+            + b_torch.numel() * b_torch.element_size()
+            + c_torch.numel() * c_torch.element_size()
+        )
+        workspace_count = testing.get_workspace_count(
+            one_workspace_bytes, warmup_iterations, iterations
+        )
+
+    exec_time = testing.benchmark(
+        compiled_gemm,
+        workspace_generator=generate_tensors,
+        workspace_count=workspace_count,
+        stream=stream,
+        warmup_iterations=warmup_iterations,
+        iterations=iterations,
+    )
+
+    return exec_time  # Return execution time in microseconds


 if __name__ == "__main__":
     args = parse_arguments()
-    run_dense_gemm(
+    run(
         args.mnkl,
         args.a_dtype,
         args.b_dtype,
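Note: the cold-L2 path rotates through several independent A/B/C workspaces so that a timed iteration never finds its operands already resident in L2. A rough sketch of the idea (the pool-sizing rule below is an assumption for illustration, not necessarily testing.get_workspace_count's exact heuristic):

    # Keep enough workspaces that cycling through them evicts L2 between
    # visits to the same tensor set; 50 MiB approximates H100's L2 size.
    l2_bytes = 50 * 1024 * 1024
    workspace_count = max(1, l2_bytes // one_workspace_bytes + 1)
    workspaces = [generate_tensors() for _ in range(workspace_count)]
    for i in range(warmup_iterations + iterations):
        args_i = workspaces[i % workspace_count]  # operands arrive L2-cold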
@@ -1488,5 +1575,9 @@ if __name__ == "__main__":
         args.tile_shape_mnk,
         args.cluster_shape_mn,
         args.tolerance,
+        args.warmup_iterations,
+        args.iterations,
+        args.skip_ref_check,
+        args.use_cold_l2,
     )
     print("PASS")