Shard gemm reference templates into multiple TUs for parallel compilation (#1043)
* Split apart gemm reference templates into multiple TUs for parallel compilation * remove old files * better balancing of ref kernels across TUs * remove 3 new added refcheck kernels and some un-necessary fp8 library instances to reduce lib size * remove auto fp8 kernels * remove some redundant kernels
This commit is contained in:
@ -4105,24 +4105,18 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_medium = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_large = [
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
@ -4264,7 +4258,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
DataType.tf32, DataType.tf32, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add)
|
||||
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 90
|
||||
|
||||
@ -4277,8 +4271,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
tile_descriptions_medium = [
|
||||
@ -4286,17 +4278,13 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
|
||||
tile_descriptions_small = [
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1])
|
||||
]
|
||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
||||
|
||||
@ -4341,7 +4329,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
])
|
||||
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions_medium, data_types, [
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
@ -4367,7 +4355,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
])
|
||||
else:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions, data_types, schedules_default)
|
||||
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue)
|
||||
|
||||
#
|
||||
@ -4402,16 +4390,12 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_medium = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
||||
|
||||
@ -4607,8 +4591,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions = [
|
||||
# 128x128x128
|
||||
@ -4616,10 +4598,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
elif math_inst.instruction_shape[1] == 64:
|
||||
tile_descriptions = [
|
||||
# 256x64x128
|
||||
@ -4627,33 +4606,31 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
else:
|
||||
assert False, "math inst is not supported"
|
||||
|
||||
# some schedules disabled to save on library size
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
schedules = [
|
||||
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
# [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
# [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
]
|
||||
stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]
|
||||
else:
|
||||
schedules = [
|
||||
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||
# [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
# TmaWarpSpecializedCooperative require CUDA version >= 12.1 for optimal performance.
|
||||
]
|
||||
stream_k_schedules = []
|
||||
|
||||
|
||||
for data_type in data_types:
|
||||
# With No-SMEM epilogues
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules)
|
||||
@ -4661,8 +4638,8 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
# Persistent kernels with TMA epilogues
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
|
||||
# Small tiles
|
||||
@ -4673,7 +4650,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
# Add stream-K variants (with and without TMA epilogues)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]],
|
||||
tile_schedulers=[TileSchedulerType.StreamK])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user