v4.0 update. (#2371)
This commit is contained in:
@ -345,6 +345,15 @@ def get_real_from_complex(complex_type):
|
||||
return r
|
||||
return DataType.invalid
|
||||
|
||||
# TMA requires an alignment of 128 bits for all data types
|
||||
def get_tma_alignment(data_type):
|
||||
if data_type == DataType.void:
|
||||
return 0
|
||||
elif DataTypeSize[data_type] == 6:
|
||||
return 128 # 96B alignment for 16U6 format
|
||||
else:
|
||||
return 128 // DataTypeSize[data_type]
|
||||
|
||||
#
|
||||
class ComplexMultiplyOp(enum.Enum):
|
||||
multiply_add = enum_auto()
|
||||
@ -546,6 +555,9 @@ class KernelScheduleType(enum.Enum):
|
||||
|
||||
F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
|
||||
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()
|
||||
|
||||
KernelScheduleTag = {
|
||||
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
|
||||
KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
|
||||
@ -614,7 +626,10 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120',
|
||||
|
||||
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120'
|
||||
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120',
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
|
||||
}
|
||||
|
||||
#
|
||||
@ -685,7 +700,10 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32',
|
||||
|
||||
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q'
|
||||
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q',
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q',
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: '_pingpong_q'
|
||||
}
|
||||
|
||||
class EpilogueScheduleType(enum.Enum):
|
||||
@ -756,6 +774,20 @@ EpilogueFunctor3xTag = {
|
||||
EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
|
||||
}
|
||||
|
||||
# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type)
|
||||
def is_tma_epilogue(epilogue_schedule_type):
|
||||
return epilogue_schedule_type in [
|
||||
EpilogueScheduleType.ScheduleAuto,
|
||||
EpilogueScheduleType.TmaWarpSpecialized,
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm,
|
||||
EpilogueScheduleType.TmaWarpSpecialized2Sm,
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
]
|
||||
|
||||
def to_grouped_schedule(schedule, grouped):
|
||||
if not grouped:
|
||||
return schedule
|
||||
@ -771,17 +803,18 @@ def to_grouped_schedule(schedule, grouped):
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
|
||||
# SM100
|
||||
KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100,
|
||||
KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100,
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100,
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
|
||||
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100,
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100,
|
||||
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
|
||||
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
||||
}
|
||||
|
||||
return group_schedule_map[schedule]
|
||||
|
||||
Reference in New Issue
Block a user