Hopper Grouped GEMM support for FP8 Accum (#2123)

* Add support for FP8 fast accum, with profiler extension

* Update .gitignore

* contri

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Author: ANIKET SHIVAM
Date: 2025-02-20 18:55:26 -08:00 (committed by GitHub)
Parent: b84e9802d8
Commit: 9b3772dfa6
9 changed files with 914 additions and 71 deletions


@@ -488,6 +488,10 @@ class KernelScheduleType(enum.Enum):
   TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
   TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
   ImplicitTmaWarpSpecializedSm90 = enum_auto()
+  PtrArrayTmaWarpSpecializedCooperative = enum_auto()
+  PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
+  PtrArrayTmaWarpSpecializedPingpong = enum_auto()
+  PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
   TmaWarpSpecialized1SmSm100 = enum_auto()
   TmaWarpSpecialized2SmSm100 = enum_auto()
@@ -514,11 +518,6 @@ class KernelScheduleType(enum.Enum):
   Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
   Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
-  KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
-  KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
-  KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
-  KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
 #
 KernelScheduleTag = {
   KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
@@ -551,10 +550,10 @@ KernelScheduleTag = {
   KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
   KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
   KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
   KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
@@ -598,10 +597,10 @@ KernelScheduleSuffixes = {
   KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
   KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
   KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
   KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
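
For orientation, the two tables above work in tandem when the generator names a kernel: the schedule enum member selects both the C++ schedule tag and the procedural-name suffix. A small illustration using only entries visible in these hunks:

```python
# Resolve the new grouped FP8 fast-accum schedule through both tables.
schedule = KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum

print(KernelScheduleTag[schedule])
# cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum

print(KernelScheduleSuffixes[schedule])
# _warpspecialized_cooperative_fp8_fastaccum
```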
@@ -667,8 +666,8 @@ EpilogueScheduleSuffixes = {
   EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
   EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm',
   EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm',
-  EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative',
-  EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong',
+  EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
+  EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
 }
 class EpilogueFunctor3x(enum.Enum):
@@ -686,6 +685,15 @@ def to_grouped_schedule(schedule, grouped):
     return schedule
   group_schedule_map = {
+    # SM90
+    KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+    KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
+    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
+    KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
+    EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
+    EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+    EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
+    # SM100
     KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
     KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
     KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
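
The hunk shows only the new mapping entries plus the `return schedule` fall-through, so here is a minimal sketch of how `to_grouped_schedule` presumably consumes the table; the exact body and lookup style are assumptions, while the entries come from the diff:

```python
def to_grouped_schedule(schedule, grouped):
  # Dense (non-grouped) kernels keep their schedule unchanged.
  if not grouped:
    return schedule
  group_schedule_map = {
    # Two representative SM90 entries from the hunk above; the full table
    # also covers pingpong, the epilogue schedules, and the SM100 entries.
    KernelScheduleType.TmaWarpSpecializedCooperative:
      KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum:
      KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
  }
  # Grouped kernels swap in their ptr-array counterpart.
  return group_schedule_map[schedule]
```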


@@ -494,8 +494,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
   # the following cases are unsupported by grouped GEMM
   if not is_aligned:
     return [], []
-  if not can_do_tma_epilogue:
-    return [], []
   if requires_transposed_epilogue:
     return [], []
@@ -513,16 +511,15 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
       return [], []
   if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
     schedules = []
-    if not grouped:
-      schedules.append(
-        [
-          KernelScheduleType.TmaWarpSpecializedCooperative,
-          EpilogueScheduleType.TmaWarpSpecializedCooperative
-        ])
-    schedules.append(
-      [
-        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum if not grouped else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
-        EpilogueScheduleType.TmaWarpSpecializedCooperative if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
-      ])
+    schedules.append(
+      [
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+        to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+      ])
+    schedules.append(
+      [
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+        to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+      ])
     return schedules, []
   return [], []
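
To make the refactor concrete: with `grouped=True`, each `to_grouped_schedule` call above resolves through the mapping table from the first file, so this early FP8-friendly path now yields the ptr-array pairings (worked illustration, not code from the diff):

```python
# What the two appends above produce when grouped=True:
schedules = [
  [KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
   EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative],
  [KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
   EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative],
]
# With grouped=False the same calls return the dense schedules unchanged.
```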
@@ -586,18 +583,9 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     return schedules, stream_k_schedules
-  if grouped:
-    pingpong = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
-    cooperative = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum
-    if can_do_tma_epilogue:
-      schedules.append([pingpong, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong])
-    if can_do_cooperative:
-      schedules.append([cooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative])
-    return schedules, []
   schedules = []
-  # Pruning: emit Void-C kernels with persistent kernels only
-  if level >= 1 or not is_void_c:
+  # Pruning: emit Void-C and Grouped kernels with persistent kernels only
+  if (level >= 1 or not is_void_c) and not grouped:
     # Pruning: don't stamp out fp8 kernels with auto schedule
     if not is_fp8:
       schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
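
The tightened pruning condition above folds the grouped case in: auto-schedule kernels are only stamped out for non-grouped GEMMs, since grouped kernels must use a persistent ptr-array schedule. A tiny illustration of the predicate (hypothetical helper name; the expression is the one in the hunk):

```python
def emit_auto_schedule(level, is_void_c, grouped):
  # Same predicate as `(level >= 1 or not is_void_c) and not grouped` above.
  return (level >= 1 or not is_void_c) and not grouped

assert emit_auto_schedule(0, False, False)     # dense kernel with non-void C
assert not emit_auto_schedule(0, False, True)  # grouped: never auto-schedule
assert not emit_auto_schedule(0, True, False)  # void-C pruned below level 1
```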
@@ -610,28 +598,29 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
       # Inconsistency: fp8 pingpong only gets stamped out with fast accum
       if not is_fp8 or level >= 1:
         schedules.append([
-          KernelScheduleType.TmaWarpSpecializedPingpong,
-          EpilogueScheduleType.TmaWarpSpecialized
+          to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
+          to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
         ])
       if can_do_fp8_fast_accum:
         schedules.append([
-          KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
-          EpilogueScheduleType.TmaWarpSpecialized
+          to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped),
+          to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
         ])
     if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
-      # Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue
+      # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue
       if not is_fp8 or level >= 1:
-        schedules.append([KernelScheduleType.TmaWarpSpecializedPingpong, default_epilogue])
+        schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])
       if can_do_fp8_fast_accum:
-        schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
-        schedules.append([KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, default_epilogue])
+        if not grouped:
+          schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
+        schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
       if can_do_cooperative:
         schedules.append([
-          KernelScheduleType.TmaWarpSpecializedCooperative,
-          default_epilogue
+          to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+          to_grouped_schedule(default_epilogue, grouped)
         ])
         stream_k_schedules.append([
           KernelScheduleType.TmaWarpSpecializedCooperative,
@@ -639,8 +628,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
         ])
       if can_do_fp8_fast_accum:
         schedules.append([
-          KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
-          default_epilogue
+          to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+          to_grouped_schedule(default_epilogue, grouped)
         ])
         stream_k_schedules.append([
           KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
@@ -652,8 +641,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     assert not requires_transposed_epilogue
     if can_do_cooperative:
       schedules.append([
-        KernelScheduleType.TmaWarpSpecializedCooperative,
-        EpilogueScheduleType.TmaWarpSpecializedCooperative
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+        to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
       ])
       stream_k_schedules.append([
         KernelScheduleType.TmaWarpSpecializedCooperative,
@@ -661,14 +650,16 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
       ])
     if can_do_fp8_fast_accum:
       schedules.append([
-        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
-        EpilogueScheduleType.TmaWarpSpecializedCooperative
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+        to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
       ])
       stream_k_schedules.append([
         KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
         EpilogueScheduleType.TmaWarpSpecializedCooperative
       ])
+  # Grouped GEMM does not support the Stream-K scheduler
+  if grouped:
+    return schedules, []
   return schedules, stream_k_schedules
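
Taken together, the hunks above make four grouped SM90 kernel/TMA-epilogue pairings reachable, with FP8 inputs routed to the fast-accum variants, and grouped kernels always returning an empty stream-K list. A summary illustration assembled from `group_schedule_map` (not code from the diff):

```python
# SM90 grouped pairings after this change, per group_schedule_map:
# (kernel schedule, TMA epilogue schedule)
sm90_grouped_pairs = [
  (KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
   EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative),
  (KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
   EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative),
  (KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
   EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong),
  (KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
   EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong),
]
# Regardless of pairing, grouped GEMM drops stream-K:
#   if grouped: return schedules, []
```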