v4.2 release. (#2587)
* Fix default cluster callback values to 1 to avoid profiler failure when these values are not set in command line. * v4.2 release.
This commit is contained in:
@ -49,7 +49,6 @@ from . import rank_2k_operation
|
||||
from . import rank_k_operation
|
||||
from . import symm_operation
|
||||
from . import trmm_operation
|
||||
|
||||
# Make enum types from library.py accessible via cutlass_library.*
|
||||
from .library import *
|
||||
|
||||
|
||||
@ -279,7 +279,7 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups
|
||||
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
):
|
||||
# For functional testing, we prefer to run reference computing on device if any
|
||||
reference_device_archs = ["100a"]
|
||||
reference_device_archs = ["100a", "103a"]
|
||||
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
|
||||
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
|
||||
|
||||
@ -287,7 +287,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
|
||||
|
||||
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
|
||||
|
||||
@ -306,6 +306,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'bf16gemm_f32_f32_f32_f32_f32',
|
||||
]
|
||||
|
||||
exclude_archs = arch not in ("103a")
|
||||
if exclude_archs:
|
||||
sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
|
||||
|
||||
sm100_mma_data_type_runtime_dtype = [
|
||||
'gemm.*f4_f4_f32_f32_f32',
|
||||
'gemm.*f6_f6_f32_f32_f32',
|
||||
@ -344,6 +348,11 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
sm103_block_scaled_data_type = [
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = [
|
||||
'4x4x1', '2x1x1',
|
||||
'0x0x1' # dynamic cluster
|
||||
@ -354,6 +363,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
if arch in ["100a", "100f"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
@ -361,15 +373,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["101a", "101f",
|
||||
]:
|
||||
elif arch in ["101a", "101f", "110a", "110f"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["120a", "120f"]:
|
||||
elif arch in ["103a"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_1sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["120a", "120f", "121a", "121f"]:
|
||||
|
||||
# blockscaled sm120_mma kernels
|
||||
blockscaled_sm120_mma_kernel_cta_tiles = [
|
||||
@ -384,7 +404,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
|
||||
else:
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f"
|
||||
raise Exception(error_message)
|
||||
|
||||
elif mode == "functional_L1":
|
||||
@ -403,16 +423,27 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
|
||||
sm103_block_scaled_data_type = [
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = ['0x0x1']
|
||||
block_scaled_layouts = ['tnt']
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
f"({block_scaled_filter_regex_2sm})" \
|
||||
f"({sm103_block_scaled_filter_regex_1sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_2sm})"
|
||||
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
|
||||
sm120_mma_kernel_cta_tiles = [
|
||||
# h1688, s1688, i16832, i8816
|
||||
@ -449,7 +480,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
problem_waves = [0.5, 1.25, 2.5]
|
||||
|
||||
kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})"
|
||||
if arch in ["120a", "120f", "121a", "121f"]:
|
||||
kernel_filter = f"({filter_regex_sm120_mma})"
|
||||
else:
|
||||
kernel_filter = f"({filter_regex_sm100_mma})"
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
@ -341,7 +341,7 @@ class GemmOperation:
|
||||
Get the tile shape passed to the collective builder.
|
||||
On Blackwell, this is different than the operation.tile_description.tile_shape.
|
||||
"""
|
||||
is_sm100_kernel = (self.arch == 100)
|
||||
is_sm100_kernel = (self.arch == 100 or self.arch == 103)
|
||||
if not is_sm100_kernel:
|
||||
return self.tile_description.tile_shape
|
||||
|
||||
@ -995,6 +995,24 @@ ${compile_guard_end}
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
|
||||
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
|
||||
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
|
||||
|
||||
|
||||
@ -90,10 +90,12 @@ try:
|
||||
raise ImportError("Disabling attempt to import cutlass_library")
|
||||
from cutlass_library.library import *
|
||||
from cutlass_library.manifest import *
|
||||
from cutlass_library.heuristics import *
|
||||
from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist
|
||||
except ImportError:
|
||||
from library import *
|
||||
from manifest import *
|
||||
from heuristics import *
|
||||
from emit_kernel_listing import emit_gemm_kernel_testlist
|
||||
###################################################################################################
|
||||
|
||||
@ -112,6 +114,10 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
|
||||
cuda_version.append(x)
|
||||
return cuda_version >= [major, minor, patch]
|
||||
|
||||
# From cuda 13.0, Thor SM is renumbered from 101 to 110
|
||||
def ThorSMRenumbering(cuda_version):
|
||||
return 110 if CudaToolkitVersionSatisfies(cuda_version, 13, 0) else 101
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
|
||||
@ -6768,9 +6774,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
},
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
math_instructions_1sm = [
|
||||
# tf32 -> f32
|
||||
MathInstruction(
|
||||
@ -6887,7 +6895,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
[[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
grouped = is_grouped(gemm_kind)
|
||||
@ -7202,9 +7211,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
@ -7889,9 +7900,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
TileSchedulerType.Default, TileSchedulerType.StreamK
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@ -8092,6 +8105,8 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]],
|
||||
[[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]],
|
||||
@ -8120,14 +8135,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
def tile_schedulers(sfdtype):
|
||||
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
|
||||
# the epilogue is the traditional linear combination, for which we already have tests with stream-K.
|
||||
if sfdtype["type"] == DataType.void:
|
||||
if sfdtype["type"] == DataType.void or grouped:
|
||||
return [TileSchedulerType.Default]
|
||||
else:
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@ -8209,6 +8226,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
@ -8246,7 +8273,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
|
||||
for data_type in data_types:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]]
|
||||
[[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
@ -8288,6 +8315,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
@ -8346,7 +8383,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0:
|
||||
continue
|
||||
|
||||
if math_inst.instruction_shape[0] == 128:
|
||||
if grouped:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
|
||||
[[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
elif math_inst.instruction_shape[0] == 128:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
@ -8396,9 +8437,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
else:
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@ -8496,6 +8539,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
@ -8625,6 +8678,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
@ -8715,6 +8778,230 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
|
||||
|
||||
def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM100 MMA with F4 + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
instruction_sizes_1sm = [
|
||||
[128, 128, 96],
|
||||
]
|
||||
|
||||
instruction_sizes_2sm = [
|
||||
[256, 128, 96],
|
||||
]
|
||||
|
||||
ab_types = [
|
||||
DataType.e2m1,
|
||||
]
|
||||
|
||||
acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions
|
||||
|
||||
min_cc = 103
|
||||
max_cc = 103
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
|
||||
is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)
|
||||
|
||||
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types):
|
||||
is_runtime_datatype_a = is_runtime_datatype(a_type)
|
||||
is_runtime_datatype_b = is_runtime_datatype(b_type)
|
||||
|
||||
# A/B datatypes should be both static or dynamic
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_1sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
a_type, b_type, acc_type,
|
||||
OpcodeClass.BlockScaledTensorOp,
|
||||
MathOperation.multiply_add,
|
||||
DataType.ue8m0) # UE8M0 scale factor
|
||||
)
|
||||
|
||||
math_instructions_2sm = []
|
||||
|
||||
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types):
|
||||
is_runtime_datatype_a = is_runtime_datatype(a_type)
|
||||
is_runtime_datatype_b = is_runtime_datatype(b_type)
|
||||
|
||||
# A/B datatypes should be both static or dynamic
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_2sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
a_type, b_type, acc_type,
|
||||
OpcodeClass.BlockScaledTensorOp,
|
||||
MathOperation.multiply_add,
|
||||
DataType.ue8m0) # UE8M0 scale factor
|
||||
)
|
||||
|
||||
cluster_shapes_1sm = [
|
||||
[1,1,1],
|
||||
# [1,2,1],
|
||||
[2,1,1],
|
||||
# [1,4,1],
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_1sm:
|
||||
multiplier_1sm = cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_1sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_1sm[1],
|
||||
768],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.e2m1,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor}
|
||||
},
|
||||
]
|
||||
|
||||
for layout in layouts:
|
||||
for data_type in data_types:
|
||||
# Set alignment d based on Destination format.
|
||||
if DataTypeSize[data_type["c_type"]] == 0 :
|
||||
layout[2][1] = 256 // DataTypeSize[data_type["d_type"]]
|
||||
else:
|
||||
layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]])
|
||||
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
|
||||
fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
|
||||
fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
|
||||
fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
|
||||
# For FP4 inputs
|
||||
if isFp4:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch
|
||||
,fp4_schedule_enable_prefetch
|
||||
]
|
||||
, gemm_kind=gemm_kind
|
||||
)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
# [2,2,1],
|
||||
# [2,4,1],
|
||||
[4,1,1],
|
||||
# [4,2,1],
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
multiplier_2sm = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_2sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_2sm[1],
|
||||
math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.e2m1,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor}
|
||||
},
|
||||
]
|
||||
|
||||
for layout in layouts:
|
||||
for data_type in data_types:
|
||||
# Set alignment d based on Destination format.
|
||||
if DataTypeSize[data_type["c_type"]] == 0 :
|
||||
layout[2][1] = 256 // DataTypeSize[data_type["d_type"]]
|
||||
else:
|
||||
layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]])
|
||||
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
|
||||
fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
# For FP4 inputs
|
||||
if isFp4:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch
|
||||
,fp4_schedule_enable_prefetch
|
||||
]
|
||||
, gemm_kind=gemm_kind
|
||||
)
|
||||
|
||||
def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
@ -8732,7 +9019,8 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
@ -8948,9 +9236,11 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@ -9074,9 +9364,11 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@ -9200,7 +9492,8 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
@ -9326,9 +9619,11 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@ -9465,9 +9760,11 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@ -9678,9 +9975,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
}
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
math_instructions_1sm = [
|
||||
MathInstruction(
|
||||
[128, 256, 8],
|
||||
@ -9772,9 +10071,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
[[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
math_instructions_1sm = [
|
||||
MathInstruction(
|
||||
[128, 256, 16],
|
||||
@ -9934,9 +10235,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = [
|
||||
@ -10084,7 +10387,8 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
return
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
minimum_compute_capability = 100
|
||||
maximum_compute_capability = thor_sm
|
||||
|
||||
@ -10238,7 +10542,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
return
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
minimum_compute_capability = 100
|
||||
maximum_compute_capability = thor_sm
|
||||
|
||||
@ -10422,7 +10727,7 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
min_cc = 120
|
||||
max_cc = 120
|
||||
max_cc = 121
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
@ -10567,7 +10872,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
min_cc = 120
|
||||
max_cc = 120
|
||||
max_cc = 121
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
@ -10720,7 +11025,7 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version):
|
||||
return [TileSchedulerType.Default]
|
||||
|
||||
min_cc = 120
|
||||
max_cc = 120
|
||||
max_cc = 121
|
||||
|
||||
kernel_schedules = [
|
||||
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120,
|
||||
@ -10840,7 +11145,7 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
|
||||
return [TileSchedulerType.Default]
|
||||
|
||||
min_cc = 120
|
||||
max_cc = 120
|
||||
max_cc = 121
|
||||
|
||||
kernel_schedulers = [
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120,
|
||||
@ -10924,7 +11229,11 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
|
||||
gemm_kind = gemm_kind)
|
||||
|
||||
def GenerateSM100(manifest, cuda_version):
|
||||
arch_family_cc = ['100f', '101f']
|
||||
arch_family_cc = ['100f', '101f', '103a']
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 13, 0):
|
||||
for old_cc, new_cc in [('101f', '110f')]:
|
||||
arch_family_cc = [cc.replace(old_cc, new_cc) for cc in arch_family_cc]
|
||||
|
||||
#
|
||||
# Dense Gemm
|
||||
#
|
||||
@ -10966,8 +11275,11 @@ def GenerateSM100(manifest, cuda_version):
|
||||
# Block Scaled Gemm
|
||||
#
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
|
||||
GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
#
|
||||
# Conv
|
||||
#
|
||||
@ -11413,7 +11725,6 @@ def numeric_log_level(log_level: str) -> int:
|
||||
raise ValueError(f'Invalid log level: {log_level}')
|
||||
return numeric_level
|
||||
|
||||
|
||||
# This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface
|
||||
# to leverage the functionality in this file without running this script via a shell prompt.
|
||||
def define_parser():
|
||||
@ -11438,6 +11749,11 @@ def define_parser():
|
||||
parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.')
|
||||
parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit")
|
||||
parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file')
|
||||
parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list')
|
||||
parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler')
|
||||
parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000'])
|
||||
parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list')
|
||||
parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py')
|
||||
parser.add_argument('--selected-kernel-list', type=str, default=None, required=False,
|
||||
help='Specify the output log file containing all enabled kernels in this build')
|
||||
parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")
|
||||
@ -11460,6 +11776,9 @@ if __name__ == "__main__":
|
||||
|
||||
archs = args.architectures.split(';')
|
||||
|
||||
if args.heuristics_problems_file:
|
||||
filter_manifest_and_write_heuristics_file(manifest, args)
|
||||
|
||||
GenerateSM50(manifest, args.cuda_version)
|
||||
GenerateSM60(manifest, args.cuda_version)
|
||||
GenerateSM61(manifest, args.cuda_version)
|
||||
@ -11468,17 +11787,20 @@ if __name__ == "__main__":
|
||||
GenerateSM80(manifest, args.cuda_version)
|
||||
GenerateSM89(manifest, args.cuda_version)
|
||||
GenerateSM90(manifest, args.cuda_version)
|
||||
|
||||
|
||||
blackwell_arch_list = [
|
||||
"100a", "100f",
|
||||
"101a", "101f",
|
||||
"120a", "120f"
|
||||
"103a", "103f",
|
||||
"110a", "110f",
|
||||
"120a", "120f",
|
||||
"121a", "121f",
|
||||
]
|
||||
blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs)
|
||||
if blackwell_enabled_arch:
|
||||
GenerateSM100(manifest, args.cuda_version)
|
||||
GenerateSM120(manifest, args.cuda_version)
|
||||
|
||||
|
||||
if 'library' in args.generator_target.split(','):
|
||||
manifest.emit(GeneratorTarget.Library)
|
||||
|
||||
|
||||
414
python/cutlass_library/heuristics.py
Normal file
414
python/cutlass_library/heuristics.py
Normal file
@ -0,0 +1,414 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for selecting CUTLASS library kernels based on problem description
|
||||
"""
|
||||
import json
|
||||
import csv
|
||||
|
||||
try:
|
||||
if CUTLASS_IGNORE_PACKAGE:
|
||||
raise ImportError("Disabling attempt to import cutlass_library")
|
||||
from cutlass_library.library import *
|
||||
from cutlass_library.generator import *
|
||||
from cutlass_library.heuristics_provider import *
|
||||
except ImportError:
|
||||
from library import *
|
||||
from generator import *
|
||||
from heuristics_provider import *
|
||||
|
||||
try:
|
||||
from .sm90_utils import (
|
||||
get_valid_schedules,
|
||||
generate_data_types_from_math_instruction,
|
||||
fix_alignments,
|
||||
)
|
||||
except ImportError:
|
||||
from sm90_utils import (
|
||||
get_valid_schedules,
|
||||
generate_data_types_from_math_instruction,
|
||||
fix_alignments,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
dtype_map = {v: k for k, v in DataTypeNames.items()}
|
||||
|
||||
def serialize_heuristics_results_to_json(problems_with_configs, outfile_path):
|
||||
"""
|
||||
Utilitiy function to write heuristics results to a json file for debug
|
||||
|
||||
args:
|
||||
problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict
|
||||
outfile_path: Outfile path
|
||||
|
||||
returns:
|
||||
None
|
||||
"""
|
||||
pc_copy = problems_with_configs.copy()
|
||||
for p in pc_copy:
|
||||
for k, v in p.items():
|
||||
if isinstance(v, DataType):
|
||||
p[k] = DataTypeNames[v]
|
||||
elif isinstance(v, LayoutType):
|
||||
p[k] = ShortLayoutTypeNames[v]
|
||||
configs = p['configs']
|
||||
for c in configs:
|
||||
for k, v in c.items():
|
||||
if isinstance(v, DataType):
|
||||
c[k] = DataTypeNames[v]
|
||||
elif isinstance(v, LayoutType):
|
||||
c[k] = ShortLayoutTypeNames[v]
|
||||
with open(outfile_path, 'w') as f:
|
||||
json.dump(pc_copy, f, indent=2)
|
||||
|
||||
def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None):
|
||||
"""
|
||||
Get heuristic-suggested GEMM kernel configurations for a single GEMM problem.
|
||||
|
||||
args:
|
||||
m, n, k: GEMM dimensions
|
||||
batch_count: batch count
|
||||
layouts: tuple of layouts of type LayoutType
|
||||
use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions
|
||||
count: Number of configs to return
|
||||
provider: Heuristics provider to use
|
||||
|
||||
returns:
|
||||
A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys:
|
||||
- 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size
|
||||
- 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size
|
||||
- 'stages': kernel pipeline stage count
|
||||
- 'cluster_m', 'cluster_n', 'cluster_k': cluster size
|
||||
- 'layout_a', 'layout_b': input tensor layouts of type LayoutType
|
||||
- 'alignment_a', 'alignment_b': input tensor alignments, in count of elements
|
||||
- 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType
|
||||
- 'swizzle_size' : suggested threadblock swizzle
|
||||
- 'split_k_slices': number of partitions of the k dimension for splitK
|
||||
- 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n')
|
||||
"""
|
||||
if provider is None:
|
||||
provider = MatmulHeuristics()
|
||||
return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count)
|
||||
|
||||
def get_gemm_configs(problems, provider=None, count=1):
|
||||
"""
|
||||
Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems.
|
||||
|
||||
args:
|
||||
problems: List of dictionaries describing GEMM problems with the following keys:
|
||||
- 'm', 'n', 'k': Matrix dimensions (required)
|
||||
- 'dtype_a': Data type of matrix A (required)
|
||||
- 'dtype_b': Data type of matrix B (required)
|
||||
- 'dtype_c': Data type of matrix C (default: None)
|
||||
- 'dtype_d': Data type of matrix D (required)
|
||||
- 'dtype_acc': Compute data type (default 'f32')
|
||||
- 'layout': Operation layout (e.g. 'tnt')
|
||||
- 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements)
|
||||
- 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements)
|
||||
- 'alpha': Scalar multiplier for A*B (default: 1.0)
|
||||
- 'beta': Scalar multiplier for C (default: 0.0)
|
||||
- 'batch_count': Number of GEMM operations in batch (default: 1)
|
||||
- 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True)
|
||||
provider: Heuristics provider to use
|
||||
count: Number of configurations to return per problem (defualt: 1)
|
||||
|
||||
returns:
|
||||
A copy of the input dictionary, with key `configs` added containing the selected gemm configs
|
||||
"""
|
||||
ret = []
|
||||
|
||||
for problem in problems:
|
||||
problem = problem.copy()
|
||||
|
||||
try:
|
||||
m = problem['m']
|
||||
n = problem['n']
|
||||
k = problem['k']
|
||||
dtype_a = problem['dtype_a']
|
||||
dtype_b = problem['dtype_b']
|
||||
dtype_d = problem['dtype_d']
|
||||
layout = problem['layout']
|
||||
except KeyError as e:
|
||||
_LOGGER.error(f"Missing required parameter {e} for problem {problem}")
|
||||
raise
|
||||
|
||||
operation = problem.get('operation', 'gemm')
|
||||
batch_count = problem.get('batch_count', 1)
|
||||
dtype_acc = problem.get('dtype_acc', 'f32')
|
||||
dtype_c = problem.get('dtype_c', None)
|
||||
alpha = problem.get('alpha', 1.0)
|
||||
beta = problem.get('beta', 0.0)
|
||||
use_fast_acc = problem.get('use_fast_acc', True)
|
||||
|
||||
if operation != OperationKindNames[OperationKind.Gemm]:
|
||||
raise ValueError(f"Unsupported operation {operation}")
|
||||
if not (len(layout) == 3 and all(c in "nt" for c in layout)):
|
||||
raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}")
|
||||
layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout)
|
||||
|
||||
try:
|
||||
dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()]
|
||||
dtypes = tuple(dtype_map[dt] for dt in dtype_list)
|
||||
except KeyError as dt:
|
||||
_LOGGER.error(f"Unsupported data type: {dt}")
|
||||
raise
|
||||
|
||||
alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]])
|
||||
alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]])
|
||||
|
||||
configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider)
|
||||
problem['configs'] = configs
|
||||
|
||||
ret.append(problem)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs):
|
||||
"""
|
||||
Generate CUTLASS operations based on the list of configs provided by the heuristic provider
|
||||
|
||||
args:
|
||||
manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
|
||||
cuda_version: Cuda compiler version for generating cutlass operations
|
||||
kernel_configs: list of configs generated by the heuristic
|
||||
|
||||
returns:
|
||||
(configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
|
||||
"""
|
||||
min_cc = 100
|
||||
max_cc = 101
|
||||
if manifest is None:
|
||||
# Use a dummy manifest so we can use existing CreateGemmOperator functions
|
||||
manifest = Manifest()
|
||||
|
||||
configs = []
|
||||
operations = []
|
||||
for config in kernel_configs:
|
||||
layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]])
|
||||
element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
|
||||
|
||||
# nvMMH assumes 2sm instruction for !(cluster_m % 2)
|
||||
is_2sm = config['cluster_m'] % 2 == 0
|
||||
instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4]
|
||||
math_instruction = MathInstruction(
|
||||
instruction_shape,
|
||||
element_a, element_b, element_accumulator,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add
|
||||
)
|
||||
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_instruction.element_a,
|
||||
"b_type" : math_instruction.element_b,
|
||||
"c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator,
|
||||
"d_type" : element_d,
|
||||
"acc_type" : math_instruction.element_accumulator,
|
||||
"epi_type" : math_instruction.element_accumulator,
|
||||
}
|
||||
]
|
||||
|
||||
tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k'])
|
||||
tile_description = TileDescription(
|
||||
[instruction_shape[0] * tile_multiplier[0],
|
||||
instruction_shape[1] * tile_multiplier[1],
|
||||
instruction_shape[2] * 4 * tile_multiplier[2]],
|
||||
0,
|
||||
[4,1,1],
|
||||
math_instruction,
|
||||
min_cc,
|
||||
max_cc,
|
||||
cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
|
||||
)
|
||||
|
||||
schedules = []
|
||||
if is_2sm:
|
||||
schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm])
|
||||
else:
|
||||
schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm])
|
||||
|
||||
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x):
|
||||
configs.append(config)
|
||||
operations.append(o)
|
||||
|
||||
|
||||
return configs, operations
|
||||
|
||||
|
||||
def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs):
|
||||
"""
|
||||
Generate CUTLASS operations based on the list of configs provided by the heuristic provider
|
||||
|
||||
args:
|
||||
manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
|
||||
cuda_version: Cuda compiler version for generating cutlass operations
|
||||
kernel_configs: list of configs generated by the heuristic
|
||||
|
||||
returns:
|
||||
(configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
|
||||
"""
|
||||
min_cc, max_cc = 90, 90
|
||||
|
||||
if manifest is None:
|
||||
# Use a dummy manifest so we can use existing CreateGemmOperator functions
|
||||
manifest = Manifest()
|
||||
|
||||
configs = []
|
||||
operations = []
|
||||
for config in kernel_configs:
|
||||
|
||||
is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128)
|
||||
layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1])
|
||||
element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
|
||||
|
||||
# instr shape and warp config are unused for emitting 3x collective builder code
|
||||
dummy_instr_shape = [0, 0, 0]
|
||||
math_instruction = MathInstruction(
|
||||
dummy_instr_shape,
|
||||
element_a, element_b, element_accumulator,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add
|
||||
)
|
||||
|
||||
data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d)
|
||||
if is_aligned:
|
||||
layout = fix_alignments(data_types, layout, alignment_bits=128)
|
||||
|
||||
# instr shape and warp config are unused for emitting 3x collective builder code
|
||||
dummy_warp_count = [0, 0, 0]
|
||||
tile_description = TileDescription(
|
||||
[config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']],
|
||||
0,
|
||||
dummy_warp_count,
|
||||
math_instruction,
|
||||
min_cc,
|
||||
max_cc,
|
||||
cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
|
||||
)
|
||||
|
||||
schedules, stream_k_schedules = get_valid_schedules(
|
||||
tile_description=tile_description,
|
||||
cuda_version=cuda_version,
|
||||
is_aligned=is_aligned,
|
||||
data_types=data_types,
|
||||
instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic
|
||||
layout=layout,
|
||||
gemm_kind=GemmKind.Universal3x,
|
||||
enable_fp8_fast_acc=config['use_fast_acc']
|
||||
)
|
||||
|
||||
if len(schedules):
|
||||
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x):
|
||||
configs.append(config)
|
||||
operations.append(o)
|
||||
|
||||
if len(stream_k_schedules):
|
||||
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types,
|
||||
stream_k_schedules,
|
||||
tile_schedulers=[TileSchedulerType.StreamK]):
|
||||
configs.append(config)
|
||||
operations.append(o)
|
||||
|
||||
|
||||
return configs, operations
|
||||
|
||||
def filter_manifest_and_write_heuristics_file(manifest, args):
|
||||
"""
|
||||
Prune a manifest according to heuristics suggestions from the problems file
|
||||
|
||||
args:
|
||||
manifest: Cutlass manifest to prune
|
||||
args: generator.py args, requires:
|
||||
- args.heuristics_problems_file
|
||||
- args.heuristics_gpu
|
||||
- args.heuristics_testlist_file
|
||||
|
||||
returns:
|
||||
A list of dictionaries, each of which has information about an operation and a problem from the input problems
|
||||
"""
|
||||
heuristics_problems = []
|
||||
with open(args.heuristics_problems_file, 'r') as f:
|
||||
heuristics_problems = json.load(f)
|
||||
gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu
|
||||
mmh = MatmulHeuristics(gpu=gpu)
|
||||
if any(('100' in arch) for arch in args.architectures.split(';')):
|
||||
mmh.set_cta_div_n(64)
|
||||
problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem)
|
||||
|
||||
all_configs_and_operations = []
|
||||
operations = []
|
||||
for problem in problems_with_configs:
|
||||
if any('90' in arch for arch in args.architectures.split(';')):
|
||||
problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
|
||||
if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')):
|
||||
problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
|
||||
|
||||
operations += problem_operations
|
||||
problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'}
|
||||
with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)]
|
||||
all_configs_and_operations += with_problem_size
|
||||
|
||||
for operation in operations:
|
||||
manifest.add_kernel_filter(f"^{operation.procedural_name()}$")
|
||||
if not all_configs_and_operations:
|
||||
raise Exception("No valid configurations generated")
|
||||
write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file)
|
||||
return all_configs_and_operations
|
||||
|
||||
def write_profiler_testlist_to_csv(configs_list, outfile_path):
|
||||
"""
|
||||
Write a list of configs to a testlist to be consumed by cutlass_profiler
|
||||
|
||||
args:
|
||||
configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries
|
||||
outfile_path: Outfile path
|
||||
|
||||
returns:
|
||||
None
|
||||
"""
|
||||
profiler_testlist = configs_list.copy()
|
||||
for c in profiler_testlist:
|
||||
for k, v in c.items():
|
||||
if isinstance(v, DataType):
|
||||
c[k] = DataTypeNames[v]
|
||||
elif isinstance(v, LayoutType):
|
||||
c[k] = ShortLayoutTypeNames[v]
|
||||
|
||||
with open(outfile_path, mode='w', newline='') as ofile:
|
||||
k_names = profiler_testlist[0].keys()
|
||||
|
||||
writer = csv.DictWriter(ofile, fieldnames=k_names)
|
||||
writer.writeheader()
|
||||
writer.writerows(profiler_testlist)
|
||||
168
python/cutlass_library/heuristics_provider.py
Normal file
168
python/cutlass_library/heuristics_provider.py
Normal file
@ -0,0 +1,168 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Providers for kernel selection heuristics
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import ctypes
|
||||
import functools
|
||||
|
||||
from library import DataType, LayoutType
|
||||
|
||||
class MatmulHeuristics:
|
||||
|
||||
def __init__(self, gpu = None):
|
||||
import nvMatmulHeuristics
|
||||
self.mmh_lib = nvMatmulHeuristics
|
||||
self.gpu = gpu
|
||||
|
||||
if 'CUTLASS_NVMMH_SO_PATH' in os.environ:
|
||||
nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH'])
|
||||
else:
|
||||
nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
|
||||
|
||||
self.lh = nvmmhInterfaceEx(
|
||||
backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
|
||||
flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
|
||||
load_discovery_implicitly=True,
|
||||
gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
|
||||
)
|
||||
self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
|
||||
|
||||
def _layout_from_cutlass(self, layouts):
|
||||
assert(len(layouts)==3)
|
||||
full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts)
|
||||
input_layouts = full_layout_str[:2].upper()
|
||||
lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR")
|
||||
return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout]
|
||||
|
||||
def _precision_from_cutlass_dtypes(self, dtypes):
|
||||
dtype_to_cublas = {
|
||||
DataType.f64: 'D',
|
||||
DataType.f32: 'S',
|
||||
DataType.f16: 'H',
|
||||
DataType.bf16: 'T',
|
||||
DataType.e4m3: 'Q',
|
||||
DataType.e5m2: 'R',
|
||||
DataType.s32: 'I',
|
||||
DataType.s8: 'B',
|
||||
}
|
||||
|
||||
dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes
|
||||
|
||||
a_c = dtype_to_cublas[dtype_a]
|
||||
|
||||
if a_c.lower() != 'q':
|
||||
return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
|
||||
else:
|
||||
return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
|
||||
|
||||
def set_cta_div_n(self, div_n):
|
||||
cta_n_div_requirement = ctypes.c_int(div_n)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
|
||||
ctypes.byref(cta_n_div_requirement),
|
||||
ctypes.sizeof(cta_n_div_requirement)
|
||||
)
|
||||
|
||||
def set_cta_div_m(self, div_m):
|
||||
cta_m_div_requirement = ctypes.c_int(div_m)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
|
||||
ctypes.byref(cta_m_div_requirement),
|
||||
ctypes.sizeof(cta_m_div_requirement)
|
||||
)
|
||||
|
||||
def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
|
||||
if use_fast_acc:
|
||||
disable_fast_acc_for_fp8 = ctypes.c_int(0)
|
||||
else:
|
||||
disable_fast_acc_for_fp8 = ctypes.c_int(1)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
|
||||
ctypes.byref(disable_fast_acc_for_fp8),
|
||||
ctypes.sizeof(disable_fast_acc_for_fp8)
|
||||
)
|
||||
|
||||
precision = self._precision_from_cutlass_dtypes(dtypes)
|
||||
layout = self._layout_from_cutlass(layouts)
|
||||
|
||||
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
|
||||
configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
|
||||
|
||||
ret = []
|
||||
for c in configs:
|
||||
kernel = c['kernel']
|
||||
problem = c['problem']
|
||||
|
||||
r = {}
|
||||
r['estimated_runtime'] = c['runtime']
|
||||
r['cta_tile_m'] = kernel.cta_tile_m
|
||||
r['cta_tile_n'] = kernel.cta_tile_n
|
||||
r['cta_tile_k'] = kernel.cta_tile_k
|
||||
r['instr_tile_m'] = kernel.instr_tile_m
|
||||
r['instr_tile_n'] = kernel.instr_tile_n
|
||||
r['instr_tile_k'] = kernel.instr_tile_k
|
||||
r['warp_tile_m'] = kernel.warp_tile_m
|
||||
r['warp_tile_n'] = kernel.warp_tile_n
|
||||
r['warp_tile_k'] = kernel.warp_tile_k
|
||||
r['cluster_m'] = kernel.cluster_m
|
||||
r['cluster_n'] = kernel.cluster_n
|
||||
r['cluster_k'] = 1
|
||||
r['layout_a'] = layouts[0]
|
||||
r['layout_b'] = layouts[1]
|
||||
r['layout_d'] = layouts[2]
|
||||
r['dtype_a'] = dtypes[0]
|
||||
r['dtype_b'] = dtypes[1]
|
||||
r['dtype_acc'] = dtypes[2]
|
||||
r['dtype_c'] = dtypes[3]
|
||||
r['dtype_d'] = dtypes[4]
|
||||
r['alignment_a'] = align_a
|
||||
r['alignment_b'] = align_b
|
||||
r['swizzle_size'] = kernel.swizzle_factor
|
||||
r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n'
|
||||
r['split_k_slices'] = kernel.split_k
|
||||
r['use_fast_acc'] = use_fast_acc
|
||||
r['voidC'] = voidC
|
||||
|
||||
ret.append(r)
|
||||
|
||||
return ret
|
||||
|
||||
@ -546,6 +546,22 @@ class KernelScheduleType(enum.Enum):
|
||||
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
# FP4 Ultra
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto()
|
||||
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto()
|
||||
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto()
|
||||
|
||||
Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto()
|
||||
Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
@ -603,6 +619,22 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
|
||||
|
||||
# FP4 Ultra
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
||||
@ -677,6 +709,21 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_1sm',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_2sm',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_1sm',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_2sm',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_1sm_nopf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_2sm_nopf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_1sm_nopf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_2sm_nopf',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_1sm_tmapf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_2sm_tmapf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_1sm_tmapf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_2sm_tmapf',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
||||
@ -713,8 +760,12 @@ class EpilogueScheduleType(enum.Enum):
|
||||
PtrArrayNoSmemWarpSpecialized = enum_auto()
|
||||
NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
TmaWarpSpecialized1Sm = enum_auto()
|
||||
@ -732,8 +783,12 @@ EpilogueScheduleTag = {
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
|
||||
@ -752,8 +807,12 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
||||
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
|
||||
|
||||
@ -526,44 +526,49 @@ class Manifest:
|
||||
if args.filter_by_cc in ['false', 'False', '0']:
|
||||
self.filter_by_cc = False
|
||||
|
||||
if args.operations == 'all':
|
||||
self.operations_enabled = []
|
||||
else:
|
||||
operations_list = [
|
||||
OperationKind.Gemm
|
||||
, OperationKind.Conv2d
|
||||
, OperationKind.Conv3d
|
||||
, OperationKind.RankK
|
||||
, OperationKind.Trmm
|
||||
, OperationKind.Symm
|
||||
]
|
||||
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
||||
if args.operations == 'all':
|
||||
self.operations_enabled = []
|
||||
else:
|
||||
operations_list = [
|
||||
OperationKind.Gemm
|
||||
, OperationKind.Conv2d
|
||||
, OperationKind.Conv3d
|
||||
, OperationKind.RankK
|
||||
, OperationKind.Trmm
|
||||
, OperationKind.Symm
|
||||
]
|
||||
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
||||
|
||||
if args.kernels == 'all':
|
||||
self.kernel_names = []
|
||||
else:
|
||||
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
|
||||
if args.kernels == 'all':
|
||||
self.kernel_names = []
|
||||
else:
|
||||
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
|
||||
|
||||
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
|
||||
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
|
||||
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
|
||||
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
|
||||
|
||||
if args.kernel_filter_file is None:
|
||||
self.kernel_filter_list = []
|
||||
else:
|
||||
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
|
||||
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
filter_count = len(self.kernel_filter_list),
|
||||
filter_file = args.kernel_filter_file))
|
||||
if args.kernel_filter_file is None:
|
||||
self.kernel_filter_list = []
|
||||
else:
|
||||
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
|
||||
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
filter_count = len(self.kernel_filter_list),
|
||||
filter_file = args.kernel_filter_file))
|
||||
|
||||
self.operation_count = 0
|
||||
self.operations_by_name = {}
|
||||
self.disable_full_archs_compilation = args.disable_full_archs_compilation
|
||||
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
|
||||
self.instantiation_level = 0
|
||||
try:
|
||||
self.instantiation_level = int(args.instantiation_level)
|
||||
except ValueError:
|
||||
self.instantiation_level = 0
|
||||
self.operation_count = 0
|
||||
self.operations_by_name = {}
|
||||
self.disable_full_archs_compilation = args.disable_full_archs_compilation
|
||||
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
|
||||
self.instantiation_level = 0
|
||||
try:
|
||||
self.instantiation_level = int(args.instantiation_level)
|
||||
except ValueError:
|
||||
self.instantiation_level = 0
|
||||
|
||||
def add_kernel_filter(self, filter_str):
|
||||
filter_re = re.compile(filter_str)
|
||||
|
||||
self.kernel_filter_list.append(filter_re)
|
||||
|
||||
def get_sm90_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992):
|
||||
# Non-negative integer which determines how many kernels are instantiated.
|
||||
|
||||
@ -407,7 +407,7 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
|
||||
|
||||
def is_tile_desc_compatible_with_cooperative(tile_description):
|
||||
# Cooperative kernels require a minimum CTA-M of 128
|
||||
return tile_description.threadblock_shape[0] >= 128
|
||||
return tile_description.threadblock_shape[0] % 128 == 0
|
||||
|
||||
|
||||
def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types):
|
||||
|
||||
Reference in New Issue
Block a user