Hopper Grouped GEMM support for FP8 Accum (#2123)
* Add support for fp8accum, with profiler extension
* Update .gitignore
* contri
---------
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -488,6 +488,10 @@ class KernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
ImplicitTmaWarpSpecializedSm90 = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
|
||||
TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
@ -514,11 +518,6 @@ class KernelScheduleType(enum.Enum):
|
||||
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
||||
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
|
||||
#
|
||||
KernelScheduleTag = {
|
||||
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
|
||||
@ -551,10 +550,10 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
|
||||
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
|
||||
@ -598,10 +597,10 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
|
||||
@ -667,8 +666,8 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
|
||||
}
|
||||
|
||||
class EpilogueFunctor3x(enum.Enum):
|
||||
@ -686,6 +685,15 @@ def to_grouped_schedule(schedule, grouped):
|
||||
return schedule
|
||||
|
||||
group_schedule_map = {
|
||||
# SM90
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
|
||||
EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
|
||||
# SM100
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
|
||||
|
||||
@ -494,8 +494,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
# the following cases are unsupported by grouped GEMM
|
||||
if not is_aligned:
|
||||
return [], []
|
||||
if not can_do_tma_epilogue:
|
||||
return [], []
|
||||
if requires_transposed_epilogue:
|
||||
return [], []
|
||||
|
||||
@ -513,16 +511,15 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
return [], []
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
|
||||
schedules = []
|
||||
if not grouped:
|
||||
schedules.append(
|
||||
[
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
])
|
||||
schedules.append(
|
||||
[
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum if not grouped else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
schedules.append(
|
||||
[
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
return schedules, []
|
||||
return [], []
|
||||
@ -586,18 +583,9 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
|
||||
return schedules, stream_k_schedules
|
||||
|
||||
if grouped:
|
||||
pingpong = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
|
||||
cooperative = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum
|
||||
if can_do_tma_epilogue:
|
||||
schedules.append([pingpong, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong])
|
||||
if can_do_cooperative:
|
||||
schedules.append([cooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative])
|
||||
return schedules, []
|
||||
|
||||
schedules = []
|
||||
# Pruning: emit Void-C kernels with persistent kernels only
|
||||
if level >= 1 or not is_void_c:
|
||||
# Pruning: emit Void-C and Grouped kernels with persistent kernels only
|
||||
if (level >= 1 or not is_void_c) and not grouped:
|
||||
# Pruning: don't stamp out fp8 kernels with auto schedule
|
||||
if not is_fp8:
|
||||
schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
|
||||
@ -610,28 +598,29 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
|
||||
if not is_fp8 or level >= 1:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
EpilogueScheduleType.TmaWarpSpecialized
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
|
||||
EpilogueScheduleType.TmaWarpSpecialized
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
|
||||
])
|
||||
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
# Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue
|
||||
# Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue
|
||||
if not is_fp8 or level >= 1:
|
||||
schedules.append([KernelScheduleType.TmaWarpSpecializedPingpong, default_epilogue])
|
||||
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])
|
||||
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
|
||||
schedules.append([KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, default_epilogue])
|
||||
if not grouped:
|
||||
schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
|
||||
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
|
||||
|
||||
if can_do_cooperative:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
default_epilogue
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(default_epilogue, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
@ -639,8 +628,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
default_epilogue
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
to_grouped_schedule(default_epilogue, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
@ -652,8 +641,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
assert not requires_transposed_epilogue
|
||||
if can_do_cooperative:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
@ -661,14 +650,16 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
||||
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
||||
])
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
])
|
||||
|
||||
# Grouped GEMM does not support the Stream-K scheduler
|
||||
if grouped:
|
||||
return schedules, []
|
||||
return schedules, stream_k_schedules
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user