v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -51,7 +51,7 @@ This GEMM kernel supports the following features:
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
- Threadblock rasterization to improve data re-use
- Supports multi-stage pipeline to overlap computation and memory access
- Implements shared memory buffering for epilogue to increase coalesed global memory access
- Implements shared memory buffering for epilogue to increase coalesced global memory access
This GEMM works as follows:
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
@ -214,7 +214,7 @@ class TensorOpGemm:
atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
)
# Creates a synchonous copy atom and thread layouts for the epilogue
# Creates a synchronous copy atom and thread layouts for the epilogue
c_copy_bits = 128
atom_sync_copy = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
@ -550,16 +550,8 @@ class TensorOpGemm:
# Creates the tiled copy so that it matches the thread-value layout
# expected by the tiled mma
tiled_copy_s2r_A = cute.make_tiled_copy(
atom_copy_s2r_A,
layout_tv=tiled_mma.tv_layout_A_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_B = cute.make_tiled_copy(
atom_copy_s2r_B,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
@ -836,8 +828,7 @@ class TensorOpGemm:
if major_mode == utils.LayoutEnum.ROW_MAJOR
else cute.make_layout((copy_elems, 1))
)
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
def raster_tile(self, i, j, f):
new_i = i // f
@ -845,20 +836,33 @@ class TensorOpGemm:
return (new_i, new_j)
def run_tensor_op_gemm(
def run(
a_major: str,
b_major: str,
c_major: str,
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
problem_shape: Tuple[int, int, int, int],
mnkl: Tuple[int, int, int, int],
atom_layout_mnk: Tuple[int, int, int],
warmup_iterations: int = 2,
iterations: int = 100,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
M, N, K, L = problem_shape
print(f"Running Ampere tensor core GEMM example:")
print(f"mnkl: {mnkl}")
print(
f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
)
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Atoms layout: {atom_layout_mnk}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")
M, N, K, L = mnkl
# Create and permute tensor A/B/C
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
@ -866,23 +870,28 @@ def run_tensor_op_gemm(
# else: (l, mode0, mode1) -> (mode0, mode1, l)
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
return (
torch_tensor = (
torch.empty(*shape, dtype=torch.int32)
.random_(-2, 2)
.to(dtype=dtype)
.to(dtype=cutlass_torch.dtype(dtype))
.permute(permute_order)
.cuda()
)
# assume input is 16B aligned
cute_tensor = (
from_dlpack(torch_tensor, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0))
.mark_compact_shape_dynamic(
mode=(1 if not is_mode0_major else 0),
stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
divisibility=(128 // dtype.width),
)
)
return cute_tensor, torch_tensor
a = create_and_permute_tensor(
L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype)
)
b = create_and_permute_tensor(
L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype)
)
c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype))
ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype))
mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
tensor_op_gemm = TensorOpGemm(
ab_dtype,
@ -891,56 +900,49 @@ def run_tensor_op_gemm(
atom_layout_mnk,
)
# assume input is 16B aligned
a_tensor = (
from_dlpack(a, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if a_major == "k" else 0),
stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0),
divisibility=(128 // ab_dtype.width),
)
)
b_tensor = (
from_dlpack(b, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if b_major == "k" else 0),
stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0),
divisibility=(128 // ab_dtype.width),
)
)
c_tensor = (
from_dlpack(c, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
.mark_compact_shape_dynamic(
mode=(1 if c_major == "n" else 0),
stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0),
divisibility=(128 // c_dtype.width),
)
)
print("Compiling kernel with cute.compile ...")
gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor)
compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC)
print("Executing GEMM kernel...")
if not skip_ref_check:
ref = torch.einsum(
"mkl,nkl->mnl",
a_torch.to(dtype=torch.float32),
b_torch.to(dtype=torch.float32),
).to(cutlass_torch.dtype(c_dtype))
compiled_gemm(mA, mB, mC)
print("Verifying results...")
torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")
def generate_tensors():
a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
return testing.JitArguments(a_workspace, b_workspace, c_workspace)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch.numel() * a_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
avg_time_us = testing.benchmark(
gemm,
kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor),
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
iterations=iterations,
use_cuda_graphs=False,
)
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
print("Verifying results...")
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")
return avg_time_us # Return execution time in microseconds
if __name__ == "__main__":
@ -985,10 +987,15 @@ if __name__ == "__main__":
parser.add_argument("--warmup_iterations", default=2, type=int)
parser.add_argument("--iterations", default=100, type=int)
parser.add_argument("--skip_ref_check", action="store_true")
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
print("Running Ampere tensor core GEMM example:")
run_tensor_op_gemm(
run(
args.a_major,
args.b_major,
args.c_major,
@ -1000,5 +1007,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")