v4.3 update. (#2709)

* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
This commit is contained in:
Junkai-Wu
2025-10-22 02:26:30 +08:00
committed by GitHub
parent e6e2cc29f5
commit b1d6e2c9b3
244 changed files with 59272 additions and 10455 deletions

View File

@ -322,7 +322,7 @@ def is_complex(data_type):
return False
def is_block_scaled(gemm_kind):
return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x, GemmKind.BlockScaledSparseUniversal3x)
def is_blockwise(gemm_kind):
return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x)
@ -548,6 +548,12 @@ class KernelScheduleType(enum.Enum):
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
SparseMxf4TmaWarpSpecialized1SmSm100 = enum_auto()
SparseMxf4TmaWarpSpecialized2SmSm100 = enum_auto()
SparseNvf4TmaWarpSpecialized1SmSm100 = enum_auto()
SparseNvf4TmaWarpSpecialized2SmSm100 = enum_auto()
SparseMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
SparseMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
# FP4 Ultra
MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
@ -586,6 +592,10 @@ class KernelScheduleType(enum.Enum):
Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto()
SparseMxf8f6f4TmaWarpSpecializedSm120 = enum_auto()
SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120 = enum_auto()
SparseNvf4TmaWarpSpecializedSm120 = enum_auto()
SparseMxf4TmaWarpSpecializedSm120 = enum_auto()
F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
@ -637,6 +647,13 @@ KernelScheduleTag = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
KernelScheduleType.SparseMxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf4Sm100',
KernelScheduleType.SparseMxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf4Sm100',
KernelScheduleType.SparseNvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.SparseNvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmNvf4Sm100',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100',
# FP4 Ultra
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
@ -694,6 +711,10 @@ KernelScheduleTag = {
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Sm120',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120',
KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120',
KernelScheduleType.SparseMxf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf4Sm120',
}
#
@ -742,6 +763,14 @@ KernelScheduleSuffixes = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
KernelScheduleType.SparseMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
KernelScheduleType.SparseMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
KernelScheduleType.SparseNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.SparseNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm',
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm',
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm',
@ -796,6 +825,11 @@ KernelScheduleSuffixes = {
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32',
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: '_q',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: '_acc2x4_q',
KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: '_o_vs16',
KernelScheduleType.SparseMxf4TmaWarpSpecializedSm120: '_o_vs32',
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q',
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q',
@ -827,6 +861,13 @@ class EpilogueScheduleType(enum.Enum):
PtrArrayTmaWarpSpecialized2Sm = enum_auto()
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1SmNvf4 = enum_auto()
TmaWarpSpecialized2SmNvf4 = enum_auto()
TmaWarpSpecialized1SmMxf4 = enum_auto()
TmaWarpSpecialized2SmMxf4 = enum_auto()
TmaWarpSpecialized1SmMxf8f6f4 = enum_auto()
TmaWarpSpecialized2SmMxf8f6f4 = enum_auto()
SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
#
EpilogueScheduleTag = {
@ -854,6 +895,13 @@ EpilogueScheduleTag = {
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong',
EpilogueScheduleType.TmaWarpSpecialized1SmNvf4: 'cutlass::epilogue::TmaWarpSpecialized1SmNvf4',
EpilogueScheduleType.TmaWarpSpecialized2SmNvf4: 'cutlass::epilogue::TmaWarpSpecialized2SmNvf4',
EpilogueScheduleType.TmaWarpSpecialized1SmMxf4: 'cutlass::epilogue::TmaWarpSpecialized1SmMxf4',
EpilogueScheduleType.TmaWarpSpecialized2SmMxf4: 'cutlass::epilogue::TmaWarpSpecialized2SmMxf4',
EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4: 'cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4',
EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4: 'cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4',
EpilogueScheduleType.SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::epilogue::SparseTmaWarpSpecializedCooperativeSm120',
}
#
@ -882,6 +930,13 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1SmNvf4: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized2SmNvf4: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1SmMxf4: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized2SmMxf4: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4: '_epi_tma',
EpilogueScheduleType.SparseTmaWarpSpecializedCooperativeSm120: '_epi_tma',
}
class EpilogueFunctor3x(enum.Enum):
@ -906,6 +961,12 @@ def is_tma_epilogue(epilogue_schedule_type):
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
EpilogueScheduleType.TmaWarpSpecialized1SmNvf4,
EpilogueScheduleType.TmaWarpSpecialized2SmNvf4,
EpilogueScheduleType.TmaWarpSpecialized1SmMxf4,
EpilogueScheduleType.TmaWarpSpecialized2SmMxf4,
EpilogueScheduleType.TmaWarpSpecialized1SmMxf8f6f4,
EpilogueScheduleType.TmaWarpSpecialized2SmMxf8f6f4,
]
def to_grouped_schedule(schedule, grouped):
@ -1040,7 +1101,8 @@ class OpcodeClass(enum.Enum):
TensorOp = enum_auto()
WmmaTensorOp = enum_auto()
SparseTensorOp = enum_auto()
BlockScaledTensorOp = enum_auto()
BlockScaledTensorOp = enum_auto()
BlockScaledSparseTensorOp = enum_auto()
OpcodeClassNames = {
@ -1048,7 +1110,8 @@ OpcodeClassNames = {
OpcodeClass.TensorOp: 'tensorop',
OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
OpcodeClass.SparseTensorOp: 'sptensorop',
OpcodeClass.BlockScaledTensorOp: 'bstensorop'
OpcodeClass.BlockScaledTensorOp: 'bstensorop',
OpcodeClass.BlockScaledSparseTensorOp: 'bssptensorop'
}
OpcodeClassTag = {
@ -1056,7 +1119,8 @@ OpcodeClassTag = {
OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp'
OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp',
OpcodeClass.BlockScaledSparseTensorOp: 'cutlass::arch::OpClassBlockScaledSparseTensorOp'
}
###################################################################################################
@ -1143,6 +1207,7 @@ class GemmKind(enum.Enum):
GroupedBlockScaledUniversal3x = enum_auto()
BlockwiseUniversal3x = enum_auto()
GroupedBlockwiseUniversal3x = enum_auto()
BlockScaledSparseUniversal3x = enum_auto()
#
GemmKindNames = {
@ -1158,7 +1223,8 @@ GemmKindNames = {
GemmKind.GroupedUniversal3x: "gemm_grouped",
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped",
GemmKind.BlockwiseUniversal3x: "gemm",
GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped"
GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped",
GemmKind.BlockScaledSparseUniversal3x: "spgemm"
}
#