v4.1 release update v2. (#2481)
This commit is contained in:
@ -38,6 +38,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@ -226,7 +227,7 @@ class PersistentDenseGemmKernel:
|
||||
self.cta_sync_bar_id = 0
|
||||
self.epilog_sync_bar_id = 1
|
||||
self.tmem_ptr_sync_bar_id = 2
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1308,11 +1309,7 @@ class PersistentDenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1824,7 +1821,7 @@ class PersistentDenseGemmKernel:
|
||||
return can_implement
|
||||
|
||||
|
||||
def run_dense_gemm(
|
||||
def run(
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
@ -1832,17 +1829,58 @@ def run_dense_gemm(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
use_2cta_instrs: bool,
|
||||
use_tma_store: bool,
|
||||
tolerance: float,
|
||||
mma_tiler_mn: Tuple[int, int] = (256, 256),
|
||||
cluster_shape_mn: Tuple[int, int] = (2, 1),
|
||||
use_2cta_instrs: bool = True,
|
||||
use_tma_store: bool = True,
|
||||
tolerance: float = 1e-01,
|
||||
warmup_iterations: int = 0,
|
||||
iterations: int = 1,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
|
||||
"""Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
|
||||
|
||||
This function prepares input tensors, configures and launches the persistent GEMM kernel,
|
||||
optionally performs reference validation, and benchmarks the execution performance.
|
||||
|
||||
:param mnkl: Problem size (M, N, K, L)
|
||||
:type mnkl: Tuple[int, int, int, int]
|
||||
:param ab_dtype: Data type for input tensors A and B
|
||||
:type ab_dtype: Type[cutlass.Numeric]
|
||||
:param c_dtype: Data type for output tensor C
|
||||
:type c_dtype: Type[cutlass.Numeric]
|
||||
:param acc_dtype: Data type for accumulation during matrix multiplication
|
||||
:type acc_dtype: Type[cutlass.Numeric]
|
||||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||||
:type a_major/b_major/c_major: str
|
||||
:param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the
|
||||
default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type mma_tiler_mn: Tuple[int, int], optional
|
||||
:param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
|
||||
default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type cluster_shape_mn: Tuple[int, int], optional
|
||||
:param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
|
||||
will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type use_2cta_instrs: bool, optional
|
||||
:param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use
|
||||
the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type use_tma_store: bool, optional
|
||||
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
|
||||
:type tolerance: float, optional
|
||||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
|
||||
:type warmup_iterations: int, optional
|
||||
:param iterations: Number of benchmark iterations to run, defaults to 1
|
||||
:type iterations: int, optional
|
||||
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
||||
:type skip_ref_check: bool, optional
|
||||
:param use_cold_l2: Whether to use a circular-buffer strategy to ensure a cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:raises RuntimeError: If CUDA GPU is not available
|
||||
:raises ValueError: If the configuration is invalid or unsupported by the kernel
|
||||
:return: Execution time of the GEMM kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
print(f"Running Blackwell Persistent Dense GEMM test with:")
|
||||
print(f"mnkl: {mnkl}")
|
||||
@ -1855,6 +1893,7 @@ def run_dense_gemm(
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
|
||||
# Unpack parameters
|
||||
m, n, k, l = mnkl
|
||||
@ -1931,15 +1970,15 @@ def run_dense_gemm(
|
||||
is_dynamic_layout=is_dynamic_layout,
|
||||
)
|
||||
|
||||
return f32_torch_tensor, cute_tensor, torch_tensor
|
||||
return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu
|
||||
|
||||
a_ref, a_tensor, a_torch = create_and_permute_tensor(
|
||||
a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor(
|
||||
l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
|
||||
)
|
||||
b_ref, b_tensor, b_torch = create_and_permute_tensor(
|
||||
b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor(
|
||||
l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
|
||||
)
|
||||
c_ref, c_tensor, c_torch = create_and_permute_tensor(
|
||||
c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor(
|
||||
l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True
|
||||
)
|
||||
|
||||
@ -1967,16 +2006,8 @@ def run_dense_gemm(
|
||||
gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream
|
||||
)
|
||||
|
||||
# Launch GPU kernel
|
||||
# Warm up
|
||||
for i in range(warmup_iterations):
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
# Execution
|
||||
for i in range(iterations):
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
|
||||
# Compute reference result
|
||||
if not skip_ref_check:
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
if ab_dtype in {
|
||||
cutlass.Int8,
|
||||
cutlass.Uint8,
|
||||
@ -2028,6 +2059,40 @@ def run_dense_gemm(
|
||||
rtol=1e-05,
|
||||
)
|
||||
|
||||
def generate_tensors():
|
||||
a_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
b_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
c_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a_torch_cpu.numel() * a_torch_cpu.element_size()
|
||||
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
|
||||
+ c_torch_cpu.numel() * c_torch_cpu.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -2090,6 +2155,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -2102,7 +2173,7 @@ if __name__ == "__main__":
|
||||
if len(args.cluster_shape_mn) != 2:
|
||||
parser.error("--cluster_shape_mn must contain exactly 2 values")
|
||||
|
||||
run_dense_gemm(
|
||||
run(
|
||||
args.mnkl,
|
||||
args.ab_dtype,
|
||||
args.c_dtype,
|
||||
@ -2118,5 +2189,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
Reference in New Issue
Block a user