v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -345,6 +345,15 @@ def get_real_from_complex(complex_type):
return r
return DataType.invalid
# TMA requires an alignment of 128 bits for all data types
def get_tma_alignment(data_type):
if data_type == DataType.void:
return 0
elif DataTypeSize[data_type] == 6:
return 128 # 96B alignment for 16U6 format
else:
return 128 // DataTypeSize[data_type]
#
class ComplexMultiplyOp(enum.Enum):
multiply_add = enum_auto()
@ -546,6 +555,9 @@ class KernelScheduleType(enum.Enum):
F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()
KernelScheduleTag = {
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
@ -614,7 +626,10 @@ KernelScheduleTag = {
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120',
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120',
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120'
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120',
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
}
#
@ -685,7 +700,10 @@ KernelScheduleSuffixes = {
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32',
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32',
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q'
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q',
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q',
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: '_pingpong_q'
}
class EpilogueScheduleType(enum.Enum):
@ -756,6 +774,20 @@ EpilogueFunctor3xTag = {
EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
}
# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type)
def is_tma_epilogue(epilogue_schedule_type):
return epilogue_schedule_type in [
EpilogueScheduleType.ScheduleAuto,
EpilogueScheduleType.TmaWarpSpecialized,
EpilogueScheduleType.TmaWarpSpecializedCooperative,
EpilogueScheduleType.TmaWarpSpecialized1Sm,
EpilogueScheduleType.TmaWarpSpecialized2Sm,
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
]
def to_grouped_schedule(schedule, grouped):
if not grouped:
return schedule
@ -771,17 +803,18 @@ def to_grouped_schedule(schedule, grouped):
EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
# SM100
KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100,
KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100,
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100,
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100,
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100,
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100,
EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
}
return group_schedule_map[schedule]