v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -112,7 +112,6 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
cuda_version.append(x)
return cuda_version >= [major, minor, patch]
###################################################################################################
###################################################################################################
@ -6769,8 +6768,9 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
},
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
math_instructions_1sm = [
# tf32 -> f32
MathInstruction(
@ -6793,8 +6793,8 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1]
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline:
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
@ -6847,7 +6847,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
, DynamicClusterShape
]
@ -6887,8 +6887,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
[[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
grouped = is_grouped(gemm_kind)
math_instructions_1sm = [
@ -6950,7 +6951,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
@ -7108,7 +7109,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
, DynamicClusterShape
]
@ -7152,7 +7153,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
else:
epi_schedule = EpilogueScheduleType.ScheduleAuto
kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100
kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types,
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
@ -7201,8 +7202,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
epi_type = DataType.f32
grouped = is_grouped(gemm_kind)
@ -7270,7 +7272,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
@ -7398,11 +7400,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
( data_type["d_type"] == DataType.e5m2 ):
continue
# don't support runtime data type for grouped yet
if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8):
continue
kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm
kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized1SmSm100, grouped)
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, epi_schedule]],
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
@ -7484,7 +7483,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
, DynamicClusterShape
]
@ -7607,9 +7606,6 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
( data_type["d_type"] == DataType.e5m2 ):
continue
# don't support runtime data type for grouped yet
if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8):
continue
if grouped:
epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm
@ -7617,7 +7613,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
else:
epi_schedule = EpilogueScheduleType.ScheduleAuto
kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100
kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
@ -7852,9 +7848,6 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
# grouped GEMM does not support runtime data type yet
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
continue
kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
@ -7896,8 +7889,9 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
TileSchedulerType.Default, TileSchedulerType.StreamK
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@ -7949,7 +7943,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [
[2,1,1],
[1,1,1]
@ -8025,7 +8019,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [
[2,1,1]
, DynamicClusterShape
@ -8131,8 +8125,9 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
else:
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@ -8184,7 +8179,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [
[1,1,1],
[2,1,1]
@ -8264,7 +8259,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [
[2,1,1],
[4,1,1]
@ -8372,6 +8367,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]],
]
instruction_sizes_1sm = [
@ -8400,8 +8396,9 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
else:
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@ -8416,10 +8413,6 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
# grouped GEMM does not support runtime data type yet
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
continue
math_instructions_1sm.append(
MathInstruction(
instr_size,
@ -8447,10 +8440,6 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
# grouped GEMM does not support runtime data type yet
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
continue
math_instructions_2sm.append(
MathInstruction(
instr_size,
@ -8477,7 +8466,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [
[1,1,1],
[2,1,1]
@ -8575,8 +8564,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for layout in layouts:
for data_type in data_types:
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor):
continue
# E2M1 x E2M1, vector size 32, E8
# E2M1 x E2M1, vector size 16, UE4M3
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
@ -8604,7 +8596,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [
[2,1,1],
[4,1,1]
@ -8701,8 +8693,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for layout in layouts:
for data_type in data_types:
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor):
continue
# E2M1 x E2M1, vector size 32, E8
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
@ -8737,8 +8732,9 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
epi_type = DataType.f32
@ -8763,7 +8759,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
@ -8867,7 +8863,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
, DynamicClusterShape
]
@ -8952,8 +8948,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@ -9009,7 +9006,7 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
@ -9040,7 +9037,7 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
@ -9077,8 +9074,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@ -9134,7 +9132,7 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
@ -9165,7 +9163,7 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
@ -9202,8 +9200,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
@ -9259,7 +9258,7 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
@ -9290,7 +9289,7 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
@ -9327,8 +9326,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@ -9389,7 +9389,7 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
@ -9424,7 +9424,7 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
@ -9465,8 +9465,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@ -9525,7 +9526,7 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
@ -9587,7 +9588,7 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
@ -9677,8 +9678,9 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
}
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
math_instructions_1sm = [
MathInstruction(
[128, 256, 8],
@ -9692,7 +9694,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [
[1,2,1], [1,1,1], [1,4,1]
, DynamicClusterShape
@ -9732,7 +9734,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
, DynamicClusterShape
@ -9770,8 +9772,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
[[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
math_instructions_1sm = [
MathInstruction(
[128, 256, 16],
@ -9784,7 +9787,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [
[1,2,1], [1,1,1]
, DynamicClusterShape
@ -9858,7 +9861,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
, DynamicClusterShape
@ -9931,8 +9934,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
min_cc = 100
max_cc = 101
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = [
@ -9947,7 +9951,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [
[1,2,1], [2,1,1], [1,1,1]
, DynamicClusterShape
@ -10005,7 +10009,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
, DynamicClusterShape
@ -10080,8 +10084,9 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
thor_sm = 101
minimum_compute_capability = 100
maximum_compute_capability = 101
maximum_compute_capability = thor_sm
spatial_dims = [2, 3]
@ -10110,7 +10115,7 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
# tile_descriptions is a 2-level list.
@ -10176,7 +10181,7 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
data_types_and_instruction_shapes_2sm)
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]
for math_inst, output_type in math_instructions_w_output_2sm:
@ -10233,8 +10238,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
thor_sm = 101
minimum_compute_capability = 100
maximum_compute_capability = 101
maximum_compute_capability = thor_sm
spatial_dims = [2, 3]
stages = 0 # zero means "deduce the number of stages automatically"
@ -10258,7 +10264,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
data_types_and_instruction_shapes_1sm)
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
for math_inst, output_type in math_instructions_w_output_1sm:
@ -10323,7 +10329,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
data_types_and_instruction_shapes_2sm)
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
if 101 in manifest.compute_capabilities :
if thor_sm in manifest.compute_capabilities_baseline :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]
for math_inst, output_type in math_instructions_w_output_2sm:
@ -10704,9 +10710,9 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version):
ab_types_mxf8f6f4 = [
DataType.e2m1,
DataType.e2m3,
#DataType.e2m3,
DataType.e3m2,
DataType.e5m2,
#DataType.e5m2,
DataType.e4m3,
]
@ -10783,13 +10789,145 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version):
tile_schedulers = tile_schedulers(kernel_schedule),
gemm_kind = GemmKind.SparseUniversal3x)
def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x):
    """Generate SM120 FP8 GEMM operations with blockwise scale factors.

    Enumerates cooperative and pingpong kernel schedules over FP8 (e4m3/e5m2)
    input combinations and f16/bf16 (optionally void-C) epilogues, and registers
    the resulting kernels with *manifest* via CreateGemmUniversal3xOperator.

    Args:
        manifest: kernel manifest the generated operations are appended to.
        cuda_version: CUDA toolkit version string; requires >= 12.8.
        gemm_kind: GemmKind variant to emit (blockwise or grouped-blockwise).
    """
    # Blockwise FP8 kernels require CUDA 12.8+.
    if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
        return

    # Layouts for A, B, C and their alignments (in elements). The C/D
    # alignment placeholder (16) is overwritten per data type below.
    layouts = [
        [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 16]],
        [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 16]]
    ]

    cooperative_tile_sizes = [
        [128, 128, 128]
    ]
    # Pingpong schedules use a smaller M tile than cooperative ones.
    pingpong_tile_sizes = [
        [64, 128, 128]
    ]

    def get_tile_sizes(kernel_scheduler):
        # Tile sizes depend on the kernel schedule (pingpong vs. cooperative).
        if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120:
            return pingpong_tile_sizes
        return cooperative_tile_sizes

    def get_warp_count(kernel_scheduler):
        # Warp rasterization shape [m, n, k] per schedule.
        if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120:
            return [2, 2, 1]
        return [4, 2, 1]

    def get_sf_sizes(tile_size):
        # Enumerate scale-factor vector sizes [vec_m, vec_n, vec_k] whose M/N
        # extents evenly divide the tile extents.
        sf_sizes = []
        for vec_m in [1, 128]:
            if tile_size[0] % vec_m > 0:
                continue
            for vec_n in [1, 128]:
                # Bug fix: the N tile extent must be validated against vec_n
                # (the original checked `tile_size[1] % vec_m`, which only
                # coincidentally passes for the current 128-wide N tiles).
                if tile_size[1] % vec_n > 0:
                    continue
                sf_sizes.append(
                    [vec_m, vec_n, 128]
                )
        return sf_sizes

    cluster_shape = [1, 1, 1]
    acc_types = [DataType.f32]
    instruction_sizes = [
        [16, 8, 32]
    ]

    def tile_schedulers(kernel_schedule):
        return [TileSchedulerType.Default]

    min_cc = 120
    max_cc = 120

    kernel_schedulers = [
        KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120,
        KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120
    ]

    # Supported (A, B) FP8 type pairings.
    ab_types = [
        [DataType.e4m3, DataType.e4m3],
        [DataType.e4m3, DataType.e5m2]
    ]

    math_instructions = []
    for instr_size, ab_type, acc_type in product(instruction_sizes, ab_types, acc_types):
        a_type, b_type = ab_type
        math_instructions.append(
            MathInstruction(
                instr_size,
                a_type, b_type, acc_type,
                OpcodeClass.TensorOp,
                MathOperation.multiply_add)
        )

    # Create gemm operator for mxf8f6f4
    for kernel_schedule in kernel_schedulers:
        tile_sizes = get_tile_sizes(kernel_schedule)
        warp_count = get_warp_count(kernel_schedule)
        for math_inst in math_instructions:
            tile_descriptions = []
            for tile_size in tile_sizes:
                # One tile description per valid scale-factor vector size.
                sf_sizes = get_sf_sizes(tile_size)
                for sf_size in sf_sizes:
                    tile_descriptions.append(
                        TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape,
                                        explicit_vector_sizes=sf_size)
                    )

            # Epilogue data-type combinations: f16/bf16 D, with and without a C operand.
            data_types = [
                {
                    "a_type": math_inst.element_a,
                    "b_type": math_inst.element_b,
                    "c_type": DataType.f16,
                    "d_type": DataType.f16,
                    "acc_type": math_inst.element_accumulator,
                    "epi_type": DataType.f32
                },
                {
                    "a_type": math_inst.element_a,
                    "b_type": math_inst.element_b,
                    "c_type": DataType.bf16,
                    "d_type": DataType.bf16,
                    "acc_type": math_inst.element_accumulator,
                    "epi_type": DataType.f32
                },
                {
                    "a_type": math_inst.element_a,
                    "b_type": math_inst.element_b,
                    "c_type": DataType.void,
                    "d_type": DataType.f16,
                    "acc_type": math_inst.element_accumulator,
                    "epi_type": DataType.f32
                },
                {
                    "a_type": math_inst.element_a,
                    "b_type": math_inst.element_b,
                    "c_type": DataType.void,
                    "d_type": DataType.bf16,
                    "acc_type": math_inst.element_accumulator,
                    "epi_type": DataType.f32
                }
            ]

            for data_type in data_types:
                # Set alignment d based on Destination format: a 128-bit access
                # divided by the element width of D, in elements.
                for layout in layouts:
                    layout[2][1] = 128 // DataTypeSize[data_type["d_type"]]
                # Create gemm operator
                CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
                                              [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
                                              tile_schedulers=tile_schedulers(kernel_schedule),
                                              gemm_kind=gemm_kind)
def GenerateSM100(manifest, cuda_version):
arch_family_cc = ['100f', '101f']
#
# Dense Gemm
#
architectures = manifest.args.architectures.split(';') if len(args.architectures) else ['50',]
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version)
GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version)
@ -10797,7 +10935,7 @@ def GenerateSM100(manifest, cuda_version):
GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version)
if '100f' not in architectures and '101f' not in architectures:
if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)):
GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
@ -10819,7 +10957,7 @@ def GenerateSM100(manifest, cuda_version):
#
GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version)
GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version)
if '100f' not in architectures and '101f' not in architectures:
if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)):
GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version)
GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version)
GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
@ -10849,6 +10987,8 @@ def GenerateSM120(manifest, cuda_version):
# Sparse Gemm
#
GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version)
GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version)
GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x)
###################################################################################################
@ -11328,13 +11468,17 @@ if __name__ == "__main__":
GenerateSM80(manifest, args.cuda_version)
GenerateSM89(manifest, args.cuda_version)
GenerateSM90(manifest, args.cuda_version)
blackwell_enabled_arch = any(arch in ["100a", "100f", "101a", "101f", "120a", "120f"] for arch in archs)
blackwell_arch_list = [
"100a", "100f",
"101a", "101f",
"120a", "120f"
]
blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs)
if blackwell_enabled_arch:
GenerateSM100(manifest, args.cuda_version)
GenerateSM120(manifest, args.cuda_version)
if 'library' in args.generator_target.split(','):
manifest.emit(GeneratorTarget.Library)