v4.0 update. (#2371)
This commit is contained in:
@ -112,7 +112,6 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
|
||||
cuda_version.append(x)
|
||||
return cuda_version >= [major, minor, patch]
|
||||
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
|
||||
@ -6769,8 +6768,9 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
},
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
math_instructions_1sm = [
|
||||
# tf32 -> f32
|
||||
MathInstruction(
|
||||
@ -6793,8 +6793,8 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
|
||||
if thor_sm in manifest.compute_capabilities_baseline:
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
@ -6847,7 +6847,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
@ -6887,8 +6887,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
[[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
math_instructions_1sm = [
|
||||
@ -6950,7 +6951,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
@ -7108,7 +7109,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
@ -7152,7 +7153,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
|
||||
else:
|
||||
epi_schedule = EpilogueScheduleType.ScheduleAuto
|
||||
kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100
|
||||
kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped)
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types,
|
||||
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||
@ -7201,8 +7202,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
epi_type = DataType.f32
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
@ -7270,7 +7272,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
@ -7398,11 +7400,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
|
||||
( data_type["d_type"] == DataType.e5m2 ):
|
||||
continue
|
||||
# don't support runtime data type for grouped yet
|
||||
if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8):
|
||||
continue
|
||||
kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100
|
||||
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm
|
||||
kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized1SmSm100, grouped)
|
||||
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, epi_schedule]],
|
||||
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||
@ -7484,7 +7483,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
@ -7607,9 +7606,6 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
|
||||
( data_type["d_type"] == DataType.e5m2 ):
|
||||
continue
|
||||
# don't support runtime data type for grouped yet
|
||||
if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8):
|
||||
continue
|
||||
|
||||
if grouped:
|
||||
epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm
|
||||
@ -7617,7 +7613,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
|
||||
else:
|
||||
epi_schedule = EpilogueScheduleType.ScheduleAuto
|
||||
kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100
|
||||
kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped)
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||
@ -7852,9 +7848,6 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
# grouped GEMM does not support runtime data type yet
|
||||
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
|
||||
continue
|
||||
kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped)
|
||||
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
@ -7896,8 +7889,9 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
TileSchedulerType.Default, TileSchedulerType.StreamK
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@ -7949,7 +7943,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [
|
||||
[2,1,1],
|
||||
[1,1,1]
|
||||
@ -8025,7 +8019,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1]
|
||||
, DynamicClusterShape
|
||||
@ -8131,8 +8125,9 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
else:
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@ -8184,7 +8179,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [
|
||||
[1,1,1],
|
||||
[2,1,1]
|
||||
@ -8264,7 +8259,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
[4,1,1]
|
||||
@ -8372,6 +8367,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
|
||||
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]],
|
||||
]
|
||||
|
||||
instruction_sizes_1sm = [
|
||||
@ -8400,8 +8396,9 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
else:
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@ -8416,10 +8413,6 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
# grouped GEMM does not support runtime data type yet
|
||||
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_1sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
@ -8447,10 +8440,6 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
# grouped GEMM does not support runtime data type yet
|
||||
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_2sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
@ -8477,7 +8466,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [
|
||||
[1,1,1],
|
||||
[2,1,1]
|
||||
@ -8575,8 +8564,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
for layout in layouts:
|
||||
for data_type in data_types:
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor):
|
||||
continue
|
||||
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
# E2M1 x E2M1, vector size 16, UE4M3
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
@ -8604,7 +8596,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
[4,1,1]
|
||||
@ -8701,8 +8693,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
for layout in layouts:
|
||||
for data_type in data_types:
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor):
|
||||
continue
|
||||
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
|
||||
@ -8737,8 +8732,9 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
@ -8763,7 +8759,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
@ -8867,7 +8863,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
@ -8952,8 +8948,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@ -9009,7 +9006,7 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
@ -9040,7 +9037,7 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
@ -9077,8 +9074,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@ -9134,7 +9132,7 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
@ -9165,7 +9163,7 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
@ -9202,8 +9200,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
@ -9259,7 +9258,7 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
@ -9290,7 +9289,7 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
@ -9327,8 +9326,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@ -9389,7 +9389,7 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
@ -9424,7 +9424,7 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
@ -9465,8 +9465,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@ -9525,7 +9526,7 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
@ -9587,7 +9588,7 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
@ -9677,8 +9678,9 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
}
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
math_instructions_1sm = [
|
||||
MathInstruction(
|
||||
[128, 256, 8],
|
||||
@ -9692,7 +9694,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
@ -9732,7 +9734,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
|
||||
, DynamicClusterShape
|
||||
@ -9770,8 +9772,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
[[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
math_instructions_1sm = [
|
||||
MathInstruction(
|
||||
[128, 256, 16],
|
||||
@ -9784,7 +9787,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [1,1,1]
|
||||
, DynamicClusterShape
|
||||
@ -9858,7 +9861,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
|
||||
, DynamicClusterShape
|
||||
@ -9931,8 +9934,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
max_cc = thor_sm
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = [
|
||||
@ -9947,7 +9951,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [2,1,1], [1,1,1]
|
||||
, DynamicClusterShape
|
||||
@ -10005,7 +10009,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
|
||||
, DynamicClusterShape
|
||||
@ -10080,8 +10084,9 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
return
|
||||
|
||||
thor_sm = 101
|
||||
minimum_compute_capability = 100
|
||||
maximum_compute_capability = 101
|
||||
maximum_compute_capability = thor_sm
|
||||
|
||||
spatial_dims = [2, 3]
|
||||
|
||||
@ -10110,7 +10115,7 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
|
||||
|
||||
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
|
||||
|
||||
# tile_descriptions is a 2-level list.
|
||||
@ -10176,7 +10181,7 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
|
||||
data_types_and_instruction_shapes_2sm)
|
||||
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]
|
||||
|
||||
for math_inst, output_type in math_instructions_w_output_2sm:
|
||||
@ -10233,8 +10238,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
return
|
||||
|
||||
thor_sm = 101
|
||||
minimum_compute_capability = 100
|
||||
maximum_compute_capability = 101
|
||||
maximum_compute_capability = thor_sm
|
||||
|
||||
spatial_dims = [2, 3]
|
||||
stages = 0 # zero means "deduce the number of stages automatically"
|
||||
@ -10258,7 +10264,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
data_types_and_instruction_shapes_1sm)
|
||||
|
||||
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
|
||||
|
||||
for math_inst, output_type in math_instructions_w_output_1sm:
|
||||
@ -10323,7 +10329,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
data_types_and_instruction_shapes_2sm)
|
||||
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if thor_sm in manifest.compute_capabilities_baseline :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]
|
||||
|
||||
for math_inst, output_type in math_instructions_w_output_2sm:
|
||||
@ -10704,9 +10710,9 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version):
|
||||
|
||||
ab_types_mxf8f6f4 = [
|
||||
DataType.e2m1,
|
||||
DataType.e2m3,
|
||||
#DataType.e2m3,
|
||||
DataType.e3m2,
|
||||
DataType.e5m2,
|
||||
#DataType.e5m2,
|
||||
DataType.e4m3,
|
||||
]
|
||||
|
||||
@ -10783,13 +10789,145 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version):
|
||||
tile_schedulers = tile_schedulers(kernel_schedule),
|
||||
gemm_kind = GemmKind.SparseUniversal3x)
|
||||
|
||||
def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x):
    """Emit SM120 FP8 blockwise-scaled GEMM operators into ``manifest``.

    For every kernel schedule (cooperative and ping-pong), every FP8 math
    instruction (e4m3 x e4m3 and e4m3 x e5m2, f32 accumulation) and every
    C/D data-type combination listed below, one call to
    ``CreateGemmUniversal3xOperator`` is made.

    Args:
        manifest: Manifest object forwarded to ``CreateGemmUniversal3xOperator``.
        cuda_version: CUDA toolkit version string; nothing is generated unless
            it satisfies 12.8.
        gemm_kind: GEMM kind forwarded to ``CreateGemmUniversal3xOperator``
            (defaults to ``GemmKind.BlockwiseUniversal3x``).

    Returns:
        None. The function only has the side effect of registering operators.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
        return

    # Layouts for A, B and C/D with their alignments. The C/D alignment
    # (layout[2][1]) is overwritten per destination type further down.
    layouts = [
        [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 16]],
        [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 16]]
    ]

    cooperative_tile_sizes = [
        [128, 128, 128]
    ]
    pingpong_tile_sizes = [
        [64, 128, 128]
    ]

    def get_tile_sizes(kernel_scheduler):
        # Ping-pong uses its own tile list; every other schedule uses the cooperative one.
        if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120:
            return pingpong_tile_sizes
        return cooperative_tile_sizes

    def get_warp_count(kernel_scheduler):
        if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120:
            return [2, 2, 1]
        return [4, 2, 1]

    def get_sf_sizes(tile_size):
        # Enumerate [vec_m, vec_n, 128] scale-factor vector sizes whose M and N
        # granularities evenly divide the corresponding tile extents.
        sf_sizes = []
        for vec_m in [1, 128]:
            if tile_size[0] % vec_m > 0:
                continue
            for vec_n in [1, 128]:
                # The N extent must be checked against vec_n (the original
                # compared it against vec_m, leaving vec_n unvalidated).
                if tile_size[1] % vec_n > 0:
                    continue
                sf_sizes.append(
                    [vec_m, vec_n, 128]
                )
        return sf_sizes

    cluster_shape = [1, 1, 1]

    acc_types = [DataType.f32]

    instruction_sizes = [
        [16, 8, 32]
    ]

    def tile_schedulers(kernel_schedule):
        # Same scheduler list regardless of the kernel schedule.
        return [TileSchedulerType.Default]

    min_cc = 120
    max_cc = 120

    kernel_schedulers = [
        KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120,
        KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120
    ]

    ab_types = [
        [DataType.e4m3, DataType.e4m3],
        [DataType.e4m3, DataType.e5m2]
    ]

    math_instructions = []

    for instr_size, ab_type, acc_type in product(instruction_sizes, ab_types, acc_types):
        a_type, b_type = ab_type
        math_instructions.append(
            MathInstruction(
                instr_size,
                a_type, b_type, acc_type,
                OpcodeClass.TensorOp,
                MathOperation.multiply_add)
        )

    # (c_type, d_type) combinations emitted for every math instruction.
    cd_types = [
        (DataType.f16, DataType.f16),
        (DataType.bf16, DataType.bf16),
        (DataType.void, DataType.f16),
        (DataType.void, DataType.bf16),
    ]

    # Create gemm operators for FP8 blockwise scaling
    for kernel_schedule in kernel_schedulers:
        tile_sizes = get_tile_sizes(kernel_schedule)
        warp_count = get_warp_count(kernel_schedule)
        for math_inst in math_instructions:
            tile_descriptions = []
            for tile_size in tile_sizes:
                sf_sizes = get_sf_sizes(tile_size)
                for sf_size in sf_sizes:
                    tile_descriptions.append(
                        TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape,
                                        explicit_vector_sizes=sf_size)
                    )

            data_types = [
                {
                    "a_type"   : math_inst.element_a,
                    "b_type"   : math_inst.element_b,
                    "c_type"   : c_type,
                    "d_type"   : d_type,
                    "acc_type" : math_inst.element_accumulator,
                    "epi_type" : DataType.f32
                }
                for c_type, d_type in cd_types
            ]

            for data_type in data_types:
                # Set alignment d based on Destination format
                for layout in layouts:
                    layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]])
                # Create gemm operator
                CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
                                              [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
                                              tile_schedulers = tile_schedulers(kernel_schedule),
                                              gemm_kind = gemm_kind)
|
||||
|
||||
def GenerateSM100(manifest, cuda_version):
|
||||
arch_family_cc = ['100f', '101f']
|
||||
#
|
||||
# Dense Gemm
|
||||
#
|
||||
architectures = manifest.args.architectures.split(';') if len(args.architectures) else ['50',]
|
||||
|
||||
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version)
|
||||
|
||||
GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version)
|
||||
@ -10797,7 +10935,7 @@ def GenerateSM100(manifest, cuda_version):
|
||||
|
||||
GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version)
|
||||
|
||||
if '100f' not in architectures and '101f' not in architectures:
|
||||
if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)):
|
||||
GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)
|
||||
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
|
||||
@ -10819,7 +10957,7 @@ def GenerateSM100(manifest, cuda_version):
|
||||
#
|
||||
GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version)
|
||||
if '100f' not in architectures and '101f' not in architectures:
|
||||
if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)):
|
||||
GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
|
||||
@ -10849,6 +10987,8 @@ def GenerateSM120(manifest, cuda_version):
|
||||
# Sparse Gemm
|
||||
#
|
||||
GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version)
|
||||
GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version)
|
||||
GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
@ -11328,13 +11468,17 @@ if __name__ == "__main__":
|
||||
GenerateSM80(manifest, args.cuda_version)
|
||||
GenerateSM89(manifest, args.cuda_version)
|
||||
GenerateSM90(manifest, args.cuda_version)
|
||||
|
||||
blackwell_enabled_arch = any(arch in ["100a", "100f", "101a", "101f", "120a", "120f"] for arch in archs)
|
||||
|
||||
blackwell_arch_list = [
|
||||
"100a", "100f",
|
||||
"101a", "101f",
|
||||
"120a", "120f"
|
||||
]
|
||||
blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs)
|
||||
if blackwell_enabled_arch:
|
||||
GenerateSM100(manifest, args.cuda_version)
|
||||
GenerateSM120(manifest, args.cuda_version)
|
||||
|
||||
|
||||
if 'library' in args.generator_target.split(','):
|
||||
manifest.emit(GeneratorTarget.Library)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user