v4.2 release. (#2587)

* Fix default cluster callback values to 1 to avoid profiler failure when these values are not set in command line.

* v4.2 release.
This commit is contained in:
Junkai-Wu
2025-08-23 06:11:24 +08:00
committed by GitHub
parent 11cad1f67b
commit a49a78ffef
351 changed files with 28182 additions and 2032 deletions

View File

@ -49,7 +49,6 @@ from . import rank_2k_operation
from . import rank_k_operation
from . import symm_operation
from . import trmm_operation
# Make enum types from library.py accessible via cutlass_library.*
from .library import *

View File

@ -279,7 +279,7 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
):
# For functional testing, we prefer to run reference computing on device if any
reference_device_archs = ["100a"]
reference_device_archs = ["100a", "103a"]
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
@ -287,7 +287,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
# TODO: randomize beta values for wider coverage
beta_values = [0.5]
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
@ -306,6 +306,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'bf16gemm_f32_f32_f32_f32_f32',
]
exclude_archs = arch not in ("103a")
if exclude_archs:
sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
sm100_mma_data_type_runtime_dtype = [
'gemm.*f4_f4_f32_f32_f32',
'gemm.*f6_f6_f32_f32_f32',
@ -344,6 +348,11 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
]
sm103_block_scaled_data_type = [
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
]
block_scaled_cluster_size = [
'4x4x1', '2x1x1',
'0x0x1' # dynamic cluster
@ -354,6 +363,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
if arch in ["100a", "100f"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
@ -361,15 +373,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch in ["101a", "101f",
]:
elif arch in ["101a", "101f", "110a", "110f"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch in ["120a", "120f"]:
elif arch in ["103a"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})|" \
f"({sm103_block_scaled_filter_regex_1sm})|" \
f"({sm103_block_scaled_filter_regex_2sm})"
elif arch in ["120a", "120f", "121a", "121f"]:
# blockscaled sm120_mma kernels
blockscaled_sm120_mma_kernel_cta_tiles = [
@ -384,7 +404,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
else:
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f"
raise Exception(error_message)
elif mode == "functional_L1":
@ -403,16 +423,27 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
]
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
sm103_block_scaled_data_type = [
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
]
block_scaled_cluster_size = ['0x0x1']
block_scaled_layouts = ['tnt']
# regex list must be in kernel procedural name order
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
f"({block_scaled_filter_regex_2sm})" \
f"({sm103_block_scaled_filter_regex_1sm})|" \
f"({sm103_block_scaled_filter_regex_2sm})"
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
sm120_mma_kernel_cta_tiles = [
# h1688, s1688, i16832, i8816
@ -449,7 +480,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
problem_waves = [0.5, 1.25, 2.5]
kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})"
if arch in ["120a", "120f", "121a", "121f"]:
kernel_filter = f"({filter_regex_sm120_mma})"
else:
kernel_filter = f"({filter_regex_sm100_mma})"
else:
raise ValueError()

View File

@ -341,7 +341,7 @@ class GemmOperation:
Get the tile shape passed to the collective builder.
On Blackwell, this is different than the operation.tile_description.tile_shape.
"""
is_sm100_kernel = (self.arch == 100)
is_sm100_kernel = (self.arch == 100 or self.arch == 103)
if not is_sm100_kernel:
return self.tile_description.tile_shape
@ -995,6 +995,24 @@ ${compile_guard_end}
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'

View File

@ -90,10 +90,12 @@ try:
raise ImportError("Disabling attempt to import cutlass_library")
from cutlass_library.library import *
from cutlass_library.manifest import *
from cutlass_library.heuristics import *
from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist
except ImportError:
from library import *
from manifest import *
from heuristics import *
from emit_kernel_listing import emit_gemm_kernel_testlist
###################################################################################################
@ -112,6 +114,10 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
cuda_version.append(x)
return cuda_version >= [major, minor, patch]
# From cuda 13.0, Thor SM is renumbered from 101 to 110
def ThorSMRenumbering(cuda_version):
return 110 if CudaToolkitVersionSatisfies(cuda_version, 13, 0) else 101
###################################################################################################
###################################################################################################
@ -6768,9 +6774,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
},
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
math_instructions_1sm = [
# tf32 -> f32
MathInstruction(
@ -6887,7 +6895,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
[[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
grouped = is_grouped(gemm_kind)
@ -7202,9 +7211,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
grouped = is_grouped(gemm_kind)
@ -7889,9 +7900,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
TileSchedulerType.Default, TileSchedulerType.StreamK
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@ -8092,6 +8105,8 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
grouped = is_grouped(gemm_kind)
layouts = [
[[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]],
[[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]],
@ -8120,14 +8135,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
def tile_schedulers(sfdtype):
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
# the epilogue is the traditional linear combination, for which we already have tests with stream-K.
if sfdtype["type"] == DataType.void:
if sfdtype["type"] == DataType.void or grouped:
return [TileSchedulerType.Default]
else:
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@ -8209,6 +8226,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
@ -8246,7 +8273,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
for data_type in data_types:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]]
[[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
cluster_shapes_2sm = [
@ -8288,6 +8315,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
@ -8346,7 +8383,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0:
continue
if math_inst.instruction_shape[0] == 128:
if grouped:
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
[[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
elif math_inst.instruction_shape[0] == 128:
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
@ -8396,9 +8437,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
else:
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@ -8496,6 +8539,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.bf16,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
@ -8625,6 +8678,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.bf16,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
@ -8715,6 +8778,230 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM100 MMA with F4 + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
]
instruction_sizes_1sm = [
[128, 128, 96],
]
instruction_sizes_2sm = [
[256, 128, 96],
]
ab_types = [
DataType.e2m1,
]
acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions
min_cc = 103
max_cc = 103
epi_type = DataType.f32
math_instructions_1sm = []
is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types):
is_runtime_datatype_a = is_runtime_datatype(a_type)
is_runtime_datatype_b = is_runtime_datatype(b_type)
# A/B datatypes should be both static or dynamic
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
math_instructions_1sm.append(
MathInstruction(
instr_size,
a_type, b_type, acc_type,
OpcodeClass.BlockScaledTensorOp,
MathOperation.multiply_add,
DataType.ue8m0) # UE8M0 scale factor
)
math_instructions_2sm = []
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types):
is_runtime_datatype_a = is_runtime_datatype(a_type)
is_runtime_datatype_b = is_runtime_datatype(b_type)
# A/B datatypes should be both static or dynamic
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
math_instructions_2sm.append(
MathInstruction(
instr_size,
a_type, b_type, acc_type,
OpcodeClass.BlockScaledTensorOp,
MathOperation.multiply_add,
DataType.ue8m0) # UE8M0 scale factor
)
cluster_shapes_1sm = [
[1,1,1],
# [1,2,1],
[2,1,1],
# [1,4,1],
[4,4,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_1sm:
multiplier_1sm = cluster_shape
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_1sm[0],
math_inst.instruction_shape[1] * multiplier_1sm[1],
768],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.bf16,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.e2m1,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor}
},
]
for layout in layouts:
for data_type in data_types:
# Set alignment d based on Destination format.
if DataTypeSize[data_type["c_type"]] == 0 :
layout[2][1] = 256 // DataTypeSize[data_type["d_type"]]
else:
layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]])
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
# E2M1 x E2M1, vector size 32, E8
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
# For FP4 inputs
if isFp4:
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch
,fp4_schedule_enable_prefetch
]
, gemm_kind=gemm_kind
)
cluster_shapes_2sm = [
[2,1,1],
# [2,2,1],
# [2,4,1],
[4,1,1],
# [4,2,1],
[4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
multiplier_2sm = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_2sm[0],
math_inst.instruction_shape[1] * multiplier_2sm[1],
math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.bf16,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.e2m1,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor}
},
]
for layout in layouts:
for data_type in data_types:
# Set alignment d based on Destination format.
if DataTypeSize[data_type["c_type"]] == 0 :
layout[2][1] = 256 // DataTypeSize[data_type["d_type"]]
else:
layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]])
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
# E2M1 x E2M1, vector size 32, E8
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
# For FP4 inputs
if isFp4:
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch
,fp4_schedule_enable_prefetch
]
, gemm_kind=gemm_kind
)
def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
@ -8732,7 +9019,8 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
@ -8948,9 +9236,11 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@ -9074,9 +9364,11 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@ -9200,7 +9492,8 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
@ -9326,9 +9619,11 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@ -9465,9 +9760,11 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@ -9678,9 +9975,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
}
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
math_instructions_1sm = [
MathInstruction(
[128, 256, 8],
@ -9772,9 +10071,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
[[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
math_instructions_1sm = [
MathInstruction(
[128, 256, 16],
@ -9934,9 +10235,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = [
@ -10084,7 +10387,8 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
minimum_compute_capability = 100
maximum_compute_capability = thor_sm
@ -10238,7 +10542,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
minimum_compute_capability = 100
maximum_compute_capability = thor_sm
@ -10422,7 +10727,7 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
min_cc = 120
max_cc = 120
max_cc = 121
epi_type = DataType.f32
@ -10567,7 +10872,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
min_cc = 120
max_cc = 120
max_cc = 121
epi_type = DataType.f32
@ -10720,7 +11025,7 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version):
return [TileSchedulerType.Default]
min_cc = 120
max_cc = 120
max_cc = 121
kernel_schedules = [
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120,
@ -10840,7 +11145,7 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
return [TileSchedulerType.Default]
min_cc = 120
max_cc = 120
max_cc = 121
kernel_schedulers = [
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120,
@ -10924,7 +11229,11 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
gemm_kind = gemm_kind)
def GenerateSM100(manifest, cuda_version):
arch_family_cc = ['100f', '101f']
arch_family_cc = ['100f', '101f', '103a']
if CudaToolkitVersionSatisfies(cuda_version, 13, 0):
for old_cc, new_cc in [('101f', '110f')]:
arch_family_cc = [cc.replace(old_cc, new_cc) for cc in arch_family_cc]
#
# Dense Gemm
#
@ -10966,8 +11275,11 @@ def GenerateSM100(manifest, cuda_version):
# Block Scaled Gemm
#
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version)
#
# Conv
#
@ -11413,7 +11725,6 @@ def numeric_log_level(log_level: str) -> int:
raise ValueError(f'Invalid log level: {log_level}')
return numeric_level
# This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface
# to leverage the functionality in this file without running this script via a shell prompt.
def define_parser():
@ -11438,6 +11749,11 @@ def define_parser():
parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.')
parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit")
parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file')
parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list')
parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler')
parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000'])
parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list')
parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py')
parser.add_argument('--selected-kernel-list', type=str, default=None, required=False,
help='Specify the output log file containing all enabled kernels in this build')
parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")
@ -11460,6 +11776,9 @@ if __name__ == "__main__":
archs = args.architectures.split(';')
if args.heuristics_problems_file:
filter_manifest_and_write_heuristics_file(manifest, args)
GenerateSM50(manifest, args.cuda_version)
GenerateSM60(manifest, args.cuda_version)
GenerateSM61(manifest, args.cuda_version)
@ -11468,17 +11787,20 @@ if __name__ == "__main__":
GenerateSM80(manifest, args.cuda_version)
GenerateSM89(manifest, args.cuda_version)
GenerateSM90(manifest, args.cuda_version)
blackwell_arch_list = [
"100a", "100f",
"101a", "101f",
"120a", "120f"
"103a", "103f",
"110a", "110f",
"120a", "120f",
"121a", "121f",
]
blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs)
if blackwell_enabled_arch:
GenerateSM100(manifest, args.cuda_version)
GenerateSM120(manifest, args.cuda_version)
if 'library' in args.generator_target.split(','):
manifest.emit(GeneratorTarget.Library)

View File

@ -0,0 +1,414 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for selecting CUTLASS library kernels based on problem description
"""
import json
import csv
try:
if CUTLASS_IGNORE_PACKAGE:
raise ImportError("Disabling attempt to import cutlass_library")
from cutlass_library.library import *
from cutlass_library.generator import *
from cutlass_library.heuristics_provider import *
except ImportError:
from library import *
from generator import *
from heuristics_provider import *
try:
from .sm90_utils import (
get_valid_schedules,
generate_data_types_from_math_instruction,
fix_alignments,
)
except ImportError:
from sm90_utils import (
get_valid_schedules,
generate_data_types_from_math_instruction,
fix_alignments,
)
_LOGGER = logging.getLogger(__name__)
dtype_map = {v: k for k, v in DataTypeNames.items()}
def serialize_heuristics_results_to_json(problems_with_configs, outfile_path):
"""
Utilitiy function to write heuristics results to a json file for debug
args:
problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict
outfile_path: Outfile path
returns:
None
"""
pc_copy = problems_with_configs.copy()
for p in pc_copy:
for k, v in p.items():
if isinstance(v, DataType):
p[k] = DataTypeNames[v]
elif isinstance(v, LayoutType):
p[k] = ShortLayoutTypeNames[v]
configs = p['configs']
for c in configs:
for k, v in c.items():
if isinstance(v, DataType):
c[k] = DataTypeNames[v]
elif isinstance(v, LayoutType):
c[k] = ShortLayoutTypeNames[v]
with open(outfile_path, 'w') as f:
json.dump(pc_copy, f, indent=2)
def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None):
"""
Get heuristic-suggested GEMM kernel configurations for a single GEMM problem.
args:
m, n, k: GEMM dimensions
batch_count: batch count
layouts: tuple of layouts of type LayoutType
use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions
count: Number of configs to return
provider: Heuristics provider to use
returns:
A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys:
- 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size
- 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size
- 'stages': kernel pipeline stage count
- 'cluster_m', 'cluster_n', 'cluster_k': cluster size
- 'layout_a', 'layout_b': input tensor layouts of type LayoutType
- 'alignment_a', 'alignment_b': input tensor alignments, in count of elements
- 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType
- 'swizzle_size' : suggested threadblock swizzle
- 'split_k_slices': number of partitions of the k dimension for splitK
- 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n')
"""
if provider is None:
provider = MatmulHeuristics()
return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count)
def get_gemm_configs(problems, provider=None, count=1):
"""
Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems.
args:
problems: List of dictionaries describing GEMM problems with the following keys:
- 'm', 'n', 'k': Matrix dimensions (required)
- 'dtype_a': Data type of matrix A (required)
- 'dtype_b': Data type of matrix B (required)
- 'dtype_c': Data type of matrix C (default: None)
- 'dtype_d': Data type of matrix D (required)
- 'dtype_acc': Compute data type (default 'f32')
- 'layout': Operation layout (e.g. 'tnt')
- 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements)
- 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements)
- 'alpha': Scalar multiplier for A*B (default: 1.0)
- 'beta': Scalar multiplier for C (default: 0.0)
- 'batch_count': Number of GEMM operations in batch (default: 1)
- 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True)
provider: Heuristics provider to use
count: Number of configurations to return per problem (defualt: 1)
returns:
A copy of the input dictionary, with key `configs` added containing the selected gemm configs
"""
ret = []
for problem in problems:
problem = problem.copy()
try:
m = problem['m']
n = problem['n']
k = problem['k']
dtype_a = problem['dtype_a']
dtype_b = problem['dtype_b']
dtype_d = problem['dtype_d']
layout = problem['layout']
except KeyError as e:
_LOGGER.error(f"Missing required parameter {e} for problem {problem}")
raise
operation = problem.get('operation', 'gemm')
batch_count = problem.get('batch_count', 1)
dtype_acc = problem.get('dtype_acc', 'f32')
dtype_c = problem.get('dtype_c', None)
alpha = problem.get('alpha', 1.0)
beta = problem.get('beta', 0.0)
use_fast_acc = problem.get('use_fast_acc', True)
if operation != OperationKindNames[OperationKind.Gemm]:
raise ValueError(f"Unsupported operation {operation}")
if not (len(layout) == 3 and all(c in "nt" for c in layout)):
raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}")
layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout)
try:
dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()]
dtypes = tuple(dtype_map[dt] for dt in dtype_list)
except KeyError as dt:
_LOGGER.error(f"Unsupported data type: {dt}")
raise
alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]])
alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]])
configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider)
problem['configs'] = configs
ret.append(problem)
return ret
def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs):
"""
Generate CUTLASS operations based on the list of configs provided by the heuristic provider
args:
manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
cuda_version: Cuda compiler version for generating cutlass operations
kernel_configs: list of configs generated by the heuristic
returns:
(configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
"""
min_cc = 100
max_cc = 101
if manifest is None:
# Use a dummy manifest so we can use existing CreateGemmOperator functions
manifest = Manifest()
configs = []
operations = []
for config in kernel_configs:
layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]])
element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
# nvMMH assumes 2sm instruction for !(cluster_m % 2)
is_2sm = config['cluster_m'] % 2 == 0
instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4]
math_instruction = MathInstruction(
instruction_shape,
element_a, element_b, element_accumulator,
OpcodeClass.TensorOp,
MathOperation.multiply_add
)
data_types = [
{
"a_type" : math_instruction.element_a,
"b_type" : math_instruction.element_b,
"c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator,
"d_type" : element_d,
"acc_type" : math_instruction.element_accumulator,
"epi_type" : math_instruction.element_accumulator,
}
]
tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k'])
tile_description = TileDescription(
[instruction_shape[0] * tile_multiplier[0],
instruction_shape[1] * tile_multiplier[1],
instruction_shape[2] * 4 * tile_multiplier[2]],
0,
[4,1,1],
math_instruction,
min_cc,
max_cc,
cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
)
schedules = []
if is_2sm:
schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm])
else:
schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm])
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x):
configs.append(config)
operations.append(o)
return configs, operations
def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs):
"""
Generate CUTLASS operations based on the list of configs provided by the heuristic provider
args:
manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
cuda_version: Cuda compiler version for generating cutlass operations
kernel_configs: list of configs generated by the heuristic
returns:
(configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
"""
min_cc, max_cc = 90, 90
if manifest is None:
# Use a dummy manifest so we can use existing CreateGemmOperator functions
manifest = Manifest()
configs = []
operations = []
for config in kernel_configs:
is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128)
layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1])
element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
# instr shape and warp config are unused for emitting 3x collective builder code
dummy_instr_shape = [0, 0, 0]
math_instruction = MathInstruction(
dummy_instr_shape,
element_a, element_b, element_accumulator,
OpcodeClass.TensorOp,
MathOperation.multiply_add
)
data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d)
if is_aligned:
layout = fix_alignments(data_types, layout, alignment_bits=128)
# instr shape and warp config are unused for emitting 3x collective builder code
dummy_warp_count = [0, 0, 0]
tile_description = TileDescription(
[config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']],
0,
dummy_warp_count,
math_instruction,
min_cc,
max_cc,
cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
)
schedules, stream_k_schedules = get_valid_schedules(
tile_description=tile_description,
cuda_version=cuda_version,
is_aligned=is_aligned,
data_types=data_types,
instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic
layout=layout,
gemm_kind=GemmKind.Universal3x,
enable_fp8_fast_acc=config['use_fast_acc']
)
if len(schedules):
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x):
configs.append(config)
operations.append(o)
if len(stream_k_schedules):
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types,
stream_k_schedules,
tile_schedulers=[TileSchedulerType.StreamK]):
configs.append(config)
operations.append(o)
return configs, operations
def filter_manifest_and_write_heuristics_file(manifest, args):
"""
Prune a manifest according to heuristics suggestions from the problems file
args:
manifest: Cutlass manifest to prune
args: generator.py args, requires:
- args.heuristics_problems_file
- args.heuristics_gpu
- args.heuristics_testlist_file
returns:
A list of dictionaries, each of which has information about an operation and a problem from the input problems
"""
heuristics_problems = []
with open(args.heuristics_problems_file, 'r') as f:
heuristics_problems = json.load(f)
gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu
mmh = MatmulHeuristics(gpu=gpu)
if any(('100' in arch) for arch in args.architectures.split(';')):
mmh.set_cta_div_n(64)
problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem)
all_configs_and_operations = []
operations = []
for problem in problems_with_configs:
if any('90' in arch for arch in args.architectures.split(';')):
problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')):
problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
operations += problem_operations
problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'}
with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)]
all_configs_and_operations += with_problem_size
for operation in operations:
manifest.add_kernel_filter(f"^{operation.procedural_name()}$")
if not all_configs_and_operations:
raise Exception("No valid configurations generated")
write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file)
return all_configs_and_operations
def write_profiler_testlist_to_csv(configs_list, outfile_path):
"""
Write a list of configs to a testlist to be consumed by cutlass_profiler
args:
configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries
outfile_path: Outfile path
returns:
None
"""
profiler_testlist = configs_list.copy()
for c in profiler_testlist:
for k, v in c.items():
if isinstance(v, DataType):
c[k] = DataTypeNames[v]
elif isinstance(v, LayoutType):
c[k] = ShortLayoutTypeNames[v]
with open(outfile_path, mode='w', newline='') as ofile:
k_names = profiler_testlist[0].keys()
writer = csv.DictWriter(ofile, fieldnames=k_names)
writer.writeheader()
writer.writerows(profiler_testlist)

View File

@ -0,0 +1,168 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Providers for kernel selection heuristics
"""
import sys
import os
import glob
import logging
import ctypes
import functools
from library import DataType, LayoutType
class MatmulHeuristics:
def __init__(self, gpu = None):
import nvMatmulHeuristics
self.mmh_lib = nvMatmulHeuristics
self.gpu = gpu
if 'CUTLASS_NVMMH_SO_PATH' in os.environ:
nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH'])
else:
nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
self.lh = nvmmhInterfaceEx(
backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
load_discovery_implicitly=True,
gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
)
self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
def _layout_from_cutlass(self, layouts):
assert(len(layouts)==3)
full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts)
input_layouts = full_layout_str[:2].upper()
lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR")
return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout]
def _precision_from_cutlass_dtypes(self, dtypes):
dtype_to_cublas = {
DataType.f64: 'D',
DataType.f32: 'S',
DataType.f16: 'H',
DataType.bf16: 'T',
DataType.e4m3: 'Q',
DataType.e5m2: 'R',
DataType.s32: 'I',
DataType.s8: 'B',
}
dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes
a_c = dtype_to_cublas[dtype_a]
if a_c.lower() != 'q':
return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
else:
return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
def set_cta_div_n(self, div_n):
cta_n_div_requirement = ctypes.c_int(div_n)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
ctypes.byref(cta_n_div_requirement),
ctypes.sizeof(cta_n_div_requirement)
)
def set_cta_div_m(self, div_m):
cta_m_div_requirement = ctypes.c_int(div_m)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
ctypes.byref(cta_m_div_requirement),
ctypes.sizeof(cta_m_div_requirement)
)
def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
if use_fast_acc:
disable_fast_acc_for_fp8 = ctypes.c_int(0)
else:
disable_fast_acc_for_fp8 = ctypes.c_int(1)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
ctypes.byref(disable_fast_acc_for_fp8),
ctypes.sizeof(disable_fast_acc_for_fp8)
)
precision = self._precision_from_cutlass_dtypes(dtypes)
layout = self._layout_from_cutlass(layouts)
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
ret = []
for c in configs:
kernel = c['kernel']
problem = c['problem']
r = {}
r['estimated_runtime'] = c['runtime']
r['cta_tile_m'] = kernel.cta_tile_m
r['cta_tile_n'] = kernel.cta_tile_n
r['cta_tile_k'] = kernel.cta_tile_k
r['instr_tile_m'] = kernel.instr_tile_m
r['instr_tile_n'] = kernel.instr_tile_n
r['instr_tile_k'] = kernel.instr_tile_k
r['warp_tile_m'] = kernel.warp_tile_m
r['warp_tile_n'] = kernel.warp_tile_n
r['warp_tile_k'] = kernel.warp_tile_k
r['cluster_m'] = kernel.cluster_m
r['cluster_n'] = kernel.cluster_n
r['cluster_k'] = 1
r['layout_a'] = layouts[0]
r['layout_b'] = layouts[1]
r['layout_d'] = layouts[2]
r['dtype_a'] = dtypes[0]
r['dtype_b'] = dtypes[1]
r['dtype_acc'] = dtypes[2]
r['dtype_c'] = dtypes[3]
r['dtype_d'] = dtypes[4]
r['alignment_a'] = align_a
r['alignment_b'] = align_b
r['swizzle_size'] = kernel.swizzle_factor
r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n'
r['split_k_slices'] = kernel.split_k
r['use_fast_acc'] = use_fast_acc
r['voidC'] = voidC
ret.append(r)
return ret

View File

@ -546,6 +546,22 @@ class KernelScheduleType(enum.Enum):
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
# FP4 Ultra
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto()
Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto()
Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto()
Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
@ -603,6 +619,22 @@ KernelScheduleTag = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
# FP4 Ultra
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
@ -677,6 +709,21 @@ KernelScheduleSuffixes = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_1sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_2sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_1sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_2sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_1sm_nopf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_2sm_nopf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_1sm_nopf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_2sm_nopf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_1sm_tmapf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_2sm_tmapf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_1sm_tmapf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_2sm_tmapf',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
@ -713,8 +760,12 @@ class EpilogueScheduleType(enum.Enum):
PtrArrayNoSmemWarpSpecialized = enum_auto()
NoSmemWarpSpecialized1Sm = enum_auto()
NoSmemWarpSpecialized2Sm = enum_auto()
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1Sm = enum_auto()
@ -732,8 +783,12 @@ EpilogueScheduleTag = {
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
@ -752,8 +807,12 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',

View File

@ -526,44 +526,49 @@ class Manifest:
if args.filter_by_cc in ['false', 'False', '0']:
self.filter_by_cc = False
if args.operations == 'all':
self.operations_enabled = []
else:
operations_list = [
OperationKind.Gemm
, OperationKind.Conv2d
, OperationKind.Conv3d
, OperationKind.RankK
, OperationKind.Trmm
, OperationKind.Symm
]
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
if args.operations == 'all':
self.operations_enabled = []
else:
operations_list = [
OperationKind.Gemm
, OperationKind.Conv2d
, OperationKind.Conv3d
, OperationKind.RankK
, OperationKind.Trmm
, OperationKind.Symm
]
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
if args.kernels == 'all':
self.kernel_names = []
else:
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
if args.kernels == 'all':
self.kernel_names = []
else:
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
if args.kernel_filter_file is None:
self.kernel_filter_list = []
else:
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
filter_count = len(self.kernel_filter_list),
filter_file = args.kernel_filter_file))
if args.kernel_filter_file is None:
self.kernel_filter_list = []
else:
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
filter_count = len(self.kernel_filter_list),
filter_file = args.kernel_filter_file))
self.operation_count = 0
self.operations_by_name = {}
self.disable_full_archs_compilation = args.disable_full_archs_compilation
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
self.instantiation_level = 0
try:
self.instantiation_level = int(args.instantiation_level)
except ValueError:
self.instantiation_level = 0
self.operation_count = 0
self.operations_by_name = {}
self.disable_full_archs_compilation = args.disable_full_archs_compilation
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
self.instantiation_level = 0
try:
self.instantiation_level = int(args.instantiation_level)
except ValueError:
self.instantiation_level = 0
def add_kernel_filter(self, filter_str):
filter_re = re.compile(filter_str)
self.kernel_filter_list.append(filter_re)
def get_sm90_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992):
# Non-negative integer which determines how many kernels are instantiated.

View File

@ -407,7 +407,7 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
def is_tile_desc_compatible_with_cooperative(tile_description):
# Cooperative kernels require a minimum CTA-M of 128
return tile_description.threadblock_shape[0] >= 128
return tile_description.threadblock_shape[0] % 128 == 0
def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types):