v4.3 update. (#2709)
* v4.3 update. * Update the cute_dsl_api changelog's doc link * Update version to 4.3.0 * Update the example link * Update doc to encourage user to install DSL from requirements.txt --------- Co-authored-by: Larry Wu <larwu@nvidia.com>
This commit is contained in:
@ -322,7 +322,7 @@ def is_complex(data_type):
|
||||
return False
|
||||
|
||||
def is_block_scaled(gemm_kind):
|
||||
return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
|
||||
return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x, GemmKind.BlockScaledSparseUniversal3x)
|
||||
|
||||
def is_blockwise(gemm_kind):
|
||||
return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x)
|
||||
@ -548,6 +548,12 @@ class KernelScheduleType(enum.Enum):
|
||||
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
SparseMxf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
SparseMxf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
SparseNvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
SparseNvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
SparseMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
SparseMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
# FP4 Ultra
|
||||
MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
|
||||
MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
|
||||
@ -586,6 +592,10 @@ class KernelScheduleType(enum.Enum):
|
||||
Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto()
|
||||
|
||||
SparseMxf8f6f4TmaWarpSpecializedSm120 = enum_auto()
|
||||
SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120 = enum_auto()
|
||||
SparseNvf4TmaWarpSpecializedSm120 = enum_auto()
|
||||
SparseMxf4TmaWarpSpecializedSm120 = enum_auto()
|
||||
F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
|
||||
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
@ -637,6 +647,13 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
|
||||
|
||||
KernelScheduleType.SparseMxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf4Sm100',
|
||||
KernelScheduleType.SparseMxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf4Sm100',
|
||||
KernelScheduleType.SparseNvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmNvf4Sm100',
|
||||
KernelScheduleType.SparseNvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmNvf4Sm100',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100',
|
||||
|
||||
# FP4 Ultra
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
|
||||
@ -694,6 +711,10 @@ KernelScheduleTag = {
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Sm120',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120',
|
||||
KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120',
|
||||
KernelScheduleType.SparseMxf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf4Sm120',
|
||||
}
|
||||
|
||||
#
|
||||
@ -742,6 +763,14 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
|
||||
KernelScheduleType.SparseMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
||||
KernelScheduleType.SparseMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
||||
KernelScheduleType.SparseNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.SparseNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
|
||||
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm',
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm',
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm',
|
||||
@ -796,6 +825,11 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32',
|
||||
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: '_q',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: '_acc2x4_q',
|
||||
KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: '_o_vs16',
|
||||
KernelScheduleType.SparseMxf4TmaWarpSpecializedSm120: '_o_vs32',
|
||||
|
||||
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q',
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q',
|
||||
@ -827,6 +861,13 @@ class EpilogueScheduleType(enum.Enum):
|
||||
PtrArrayTmaWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
||||
TmaWarpSpecialized1SmNvf4 = enum_auto()
|
||||
TmaWarpSpecialized2SmNvf4 = enum_auto()
|
||||
TmaWarpSpecialized1SmMxf4 = enum_auto()
|
||||
TmaWarpSpecialized2SmMxf4 = enum_auto()
|
||||
TmaWarpSpecialized1SmMxf8f6f4 = enum_auto()
|
||||
TmaWarpSpecialized2SmMxf8f6f4 = enum_auto()
|
||||
SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
|
||||
#
|
||||
EpilogueScheduleTag = {
|
||||
@ -854,6 +895,13 @@ EpilogueScheduleTag = {
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmNvf4: 'cutlass::epilogue::TmaWarpSpecialized1SmNvf4',
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmNvf4: 'cutlass::epilogue::TmaWarpSpecialized2SmNvf4',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmMxf4: 'cutlass::epilogue::TmaWarpSpecialized1SmMxf4',
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmMxf4: 'cutlass::epilogue::TmaWarpSpecialized2SmMxf4',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4: 'cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4',
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4: 'cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4',
|
||||
EpilogueScheduleType.SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::epilogue::SparseTmaWarpSpecializedCooperativeSm120',
|
||||
}
|
||||
|
||||
#
|
||||
@ -882,6 +930,13 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmNvf4: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmNvf4: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmMxf4: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmMxf4: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4: '_epi_tma',
|
||||
EpilogueScheduleType.SparseTmaWarpSpecializedCooperativeSm120: '_epi_tma',
|
||||
}
|
||||
|
||||
class EpilogueFunctor3x(enum.Enum):
|
||||
@ -906,6 +961,12 @@ def is_tma_epilogue(epilogue_schedule_type):
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmNvf4,
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmNvf4,
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmMxf4,
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmMxf4,
|
||||
EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4,
|
||||
EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4,
|
||||
]
|
||||
|
||||
def to_grouped_schedule(schedule, grouped):
|
||||
@ -1040,7 +1101,8 @@ class OpcodeClass(enum.Enum):
|
||||
TensorOp = enum_auto()
|
||||
WmmaTensorOp = enum_auto()
|
||||
SparseTensorOp = enum_auto()
|
||||
BlockScaledTensorOp = enum_auto()
|
||||
BlockScaledTensorOp = enum_auto()
|
||||
BlockScaledSparseTensorOp = enum_auto()
|
||||
|
||||
|
||||
OpcodeClassNames = {
|
||||
@ -1048,7 +1110,8 @@ OpcodeClassNames = {
|
||||
OpcodeClass.TensorOp: 'tensorop',
|
||||
OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
|
||||
OpcodeClass.SparseTensorOp: 'sptensorop',
|
||||
OpcodeClass.BlockScaledTensorOp: 'bstensorop'
|
||||
OpcodeClass.BlockScaledTensorOp: 'bstensorop',
|
||||
OpcodeClass.BlockScaledSparseTensorOp: 'bssptensorop'
|
||||
}
|
||||
|
||||
OpcodeClassTag = {
|
||||
@ -1056,7 +1119,8 @@ OpcodeClassTag = {
|
||||
OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
|
||||
OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
|
||||
OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
|
||||
OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp'
|
||||
OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp',
|
||||
OpcodeClass.BlockScaledSparseTensorOp: 'cutlass::arch::OpClassBlockScaledSparseTensorOp'
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
@ -1143,6 +1207,7 @@ class GemmKind(enum.Enum):
|
||||
GroupedBlockScaledUniversal3x = enum_auto()
|
||||
BlockwiseUniversal3x = enum_auto()
|
||||
GroupedBlockwiseUniversal3x = enum_auto()
|
||||
BlockScaledSparseUniversal3x = enum_auto()
|
||||
|
||||
#
|
||||
GemmKindNames = {
|
||||
@ -1158,7 +1223,8 @@ GemmKindNames = {
|
||||
GemmKind.GroupedUniversal3x: "gemm_grouped",
|
||||
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped",
|
||||
GemmKind.BlockwiseUniversal3x: "gemm",
|
||||
GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped"
|
||||
GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped",
|
||||
GemmKind.BlockScaledSparseUniversal3x: "spgemm"
|
||||
}
|
||||
|
||||
#
|
||||
|
||||
Reference in New Issue
Block a user