Shard gemm reference templates into multiple TUs for parallel compilation (#1043)

* Split apart gemm reference templates into multiple TUs for parallel compilation

* remove old files

* better balancing of ref kernels across TUs

* remove 3 new added refcheck kernels and some un-necessary fp8 library instances to reduce lib size

* remove auto fp8 kernels

* remove some redundant kernels
This commit is contained in:
Vijay Thakkar
2023-08-30 16:46:30 -04:00
committed by GitHub
parent 34fd98056b
commit e01b9b5029
18 changed files with 1498 additions and 824 deletions

View File

@ -4105,24 +4105,18 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_medium = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_large = [
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
@ -4264,7 +4258,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
DataType.tf32, DataType.tf32, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add)
min_cc = 90
max_cc = 90
@ -4277,8 +4271,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_medium = [
@ -4286,17 +4278,13 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_small = [
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1])
]
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
@ -4341,7 +4329,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]
])
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions_medium, data_types, [
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]
@ -4367,7 +4355,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
])
else:
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions, data_types, schedules_default)
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue)
#
@ -4402,16 +4390,12 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions_medium = [
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
@ -4607,8 +4591,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
tile_descriptions = [
# 128x128x128
@ -4616,10 +4598,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
elif math_inst.instruction_shape[1] == 64:
tile_descriptions = [
# 256x64x128
@ -4627,33 +4606,31 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
]
else:
assert False, "math inst is not supported"
# some schedules disabled to save on library size
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
schedules = [
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
# [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
# [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
]
stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]
else:
schedules = [
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
# [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
# TmaWarpSpecializedCooperative require CUDA version >= 12.1 for optimal performance.
]
stream_k_schedules = []
for data_type in data_types:
# With No-SMEM epilogues
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules)
@ -4661,8 +4638,8 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
# Persistent kernels with TMA epilogues
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized],
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
# Small tiles
@ -4673,7 +4650,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
# Add stream-K variants (with and without TMA epilogues)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]],
tile_schedulers=[TileSchedulerType.StreamK])