Hopper Grouped GEMM support for FP8 Accum (#2123)
* Add support for fp8accum, with profiler extension
* Update .gitignore
* contri

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -488,6 +488,10 @@ class KernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
ImplicitTmaWarpSpecializedSm90 = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
|
||||
TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
@ -514,11 +518,6 @@ class KernelScheduleType(enum.Enum):
|
||||
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
||||
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
|
||||
#
|
||||
KernelScheduleTag = {
|
||||
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
|
||||
@ -551,10 +550,10 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
|
||||
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
|
||||
@ -598,10 +597,10 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
|
||||
@ -667,8 +666,8 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
|
||||
}
|
||||
|
||||
class EpilogueFunctor3x(enum.Enum):
|
||||
@ -686,6 +685,15 @@ def to_grouped_schedule(schedule, grouped):
|
||||
return schedule
|
||||
|
||||
group_schedule_map = {
|
||||
# SM90
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
|
||||
EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
|
||||
# SM100
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
|
||||
|
||||
Reference in New Issue
Block a user