v4.1 release update v2. (#2481)
This commit is contained in:
@ -51,7 +51,7 @@ This GEMM kernel supports the following features:
|
||||
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
|
||||
- Threadblock rasterization to improve data re-use
|
||||
- Supports multi-stage pipeline to overlap computation and memory access
|
||||
- Implements shared memory buffering for epilogue to increase coalesed global memory access
|
||||
- Implements shared memory buffering for epilogue to increase coalesced global memory access
|
||||
|
||||
This GEMM works as follows:
|
||||
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
|
||||
@ -214,7 +214,7 @@ class TensorOpGemm:
|
||||
atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
|
||||
)
|
||||
|
||||
# Creates a synchonous copy atom and thread layouts for the epilogue
|
||||
# Creates a synchronous copy atom and thread layouts for the epilogue
|
||||
c_copy_bits = 128
|
||||
atom_sync_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
@ -550,16 +550,8 @@ class TensorOpGemm:
|
||||
|
||||
# Creates the tiled copy so that it matches the thread-value layout
|
||||
# expected by the tiled mma
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy(
|
||||
atom_copy_s2r_A,
|
||||
layout_tv=tiled_mma.tv_layout_A_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy(
|
||||
atom_copy_s2r_B,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)
|
||||
|
||||
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
|
||||
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
|
||||
@ -836,8 +828,7 @@ class TensorOpGemm:
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
|
||||
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
|
||||
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
|
||||
|
||||
def raster_tile(self, i, j, f):
|
||||
new_i = i // f
|
||||
@ -845,20 +836,33 @@ class TensorOpGemm:
|
||||
return (new_i, new_j)
|
||||
|
||||
|
||||
def run_tensor_op_gemm(
|
||||
def run(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
problem_shape: Tuple[int, int, int, int],
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
atom_layout_mnk: Tuple[int, int, int],
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
M, N, K, L = problem_shape
|
||||
print(f"Running Ampere tensor core GEMM example:")
|
||||
print(f"mnkl: {mnkl}")
|
||||
print(
|
||||
f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
|
||||
)
|
||||
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
||||
print(f"Atoms layout: {atom_layout_mnk}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {use_cold_l2}")
|
||||
M, N, K, L = mnkl
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
|
||||
@ -866,23 +870,28 @@ def run_tensor_op_gemm(
|
||||
# else: (l, mode0, mode1) -> (mode0, mode1, l)
|
||||
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
|
||||
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
|
||||
|
||||
return (
|
||||
torch_tensor = (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(dtype=dtype)
|
||||
.to(dtype=cutlass_torch.dtype(dtype))
|
||||
.permute(permute_order)
|
||||
.cuda()
|
||||
)
|
||||
# assume input is 16B aligned
|
||||
cute_tensor = (
|
||||
from_dlpack(torch_tensor, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if not is_mode0_major else 0),
|
||||
stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
|
||||
divisibility=(128 // dtype.width),
|
||||
)
|
||||
)
|
||||
return cute_tensor, torch_tensor
|
||||
|
||||
a = create_and_permute_tensor(
|
||||
L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
b = create_and_permute_tensor(
|
||||
L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype))
|
||||
ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype))
|
||||
mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
|
||||
mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
|
||||
mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
|
||||
|
||||
tensor_op_gemm = TensorOpGemm(
|
||||
ab_dtype,
|
||||
@ -891,56 +900,49 @@ def run_tensor_op_gemm(
|
||||
atom_layout_mnk,
|
||||
)
|
||||
|
||||
# assume input is 16B aligned
|
||||
a_tensor = (
|
||||
from_dlpack(a, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
b_tensor = (
|
||||
from_dlpack(b, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
c_tensor = (
|
||||
from_dlpack(c, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0),
|
||||
divisibility=(128 // c_dtype.width),
|
||||
)
|
||||
)
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor)
|
||||
compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC)
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
if not skip_ref_check:
|
||||
ref = torch.einsum(
|
||||
"mkl,nkl->mnl",
|
||||
a_torch.to(dtype=torch.float32),
|
||||
b_torch.to(dtype=torch.float32),
|
||||
).to(cutlass_torch.dtype(c_dtype))
|
||||
compiled_gemm(mA, mB, mC)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
|
||||
b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
|
||||
c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
|
||||
return testing.JitArguments(a_workspace, b_workspace, c_workspace)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a_torch.numel() * a_torch.element_size()
|
||||
+ b_torch.numel() * b_torch.element_size()
|
||||
+ c_torch.numel() * c_torch.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
gemm,
|
||||
kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor),
|
||||
compiled_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
warmup_iterations=warmup_iterations,
|
||||
profiling_iterations=iterations,
|
||||
iterations=iterations,
|
||||
use_cuda_graphs=False,
|
||||
)
|
||||
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
|
||||
if not skip_ref_check:
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
return avg_time_us # Return execution time in microseconds
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -985,10 +987,15 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running Ampere tensor core GEMM example:")
|
||||
run_tensor_op_gemm(
|
||||
run(
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
@ -1000,5 +1007,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
Reference in New Issue
Block a user