#################################################################################################
#
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
#
#
# \brief Generates the CUTLASS kernel listing with kernel filtering
#
#
###############################################################################
# Example usage:
# generator.py --operations all --generator-target kernel_listing \
# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports
###############################################################################
import collections
import csv
import json
import math
import os
import re

try:
  import builtins
  if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
    raise ImportError("Disabling attempt to import cutlass_library")
  from cutlass_library.library import *
except ImportError:
  from library import *
audit_csv_fields = [
"KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD",
"Layout_A", "Layout_B", "Layout_C", "Layout_D",
"Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D",
"1SM/2SM",
"StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types",
"Test Counts"
]
audit_csv_runtime_fields = [
  "KernelIndex", "KernelName",
"Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K",
"Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K",
"M", "N", "K", "L", "Alpha_val", "Beta_val",
"Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled"
]
def hash_cutlass_string(input_string):
  mma_cluster_shape_pattern = r"_\d+x\d+x\d+"  # matches MMA and cluster shapes (e.g., '_128x128x256', '_0x0x1')
  output = re.sub(mma_cluster_shape_pattern, "", input_string)
return output
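# Illustrative example (hypothetical kernel name): MMA and cluster shape tokens are
# stripped so that shape variants of the same kernel hash to one family, e.g.
#   hash_cutlass_string("cutlass3x_sm100_tensorop_gemm_f16_f16_f16_f16_f16_128x128x64_2x1x1_0_tnt_align8_1sm")
#     -> "cutlass3x_sm100_tensorop_gemm_f16_f16_f16_f16_f16_0_tnt_align8_1sm"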
def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
# Define a dictionary mapping the detected types to runtime values
datatype_map = {
'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
}
# Regular expression to detect all the keys in datatype_map
pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')')
# Replace detected patterns using the dictionary
updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name)
return updated_kernel_name
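# Illustrative example: with runtime_datatype_a='e2m1' and runtime_datatype_b='e3m2',
# a hashed name containing 'f4_f6' becomes 'e2m1_e3m2', and one containing
# 'ue8m0xf4_ue8m0xf6' becomes 'ue8m0xe2m1_ue8m0xe3m2'.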
# This helper function reports foundational kernel features: datatypes, layouts, alignments, and stream-K.
def get_kernel_features(operation, kernel_name,
dynamic_datatype, runtime_input_datatype):
numcta_inst = "2sm" if "2sm" in kernel_name else "1sm"
math_inst = operation.tile_description.math_instruction
if dynamic_datatype:
dtype_name_A = runtime_input_datatype[0]
dtype_name_B = runtime_input_datatype[1]
else:
dtype_name_A = DataTypeNames[operation.A.element]
dtype_name_B = DataTypeNames[operation.B.element]
layout_name_A = ShortLayoutTypeNames[operation.A.layout]
layout_name_B = ShortLayoutTypeNames[operation.B.layout]
layout_name_C = ShortLayoutTypeNames[operation.C.layout]
layout_name_D = ShortLayoutTypeNames[operation.D.layout]
scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void
scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void)
audit_vals = [
"BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM",
kernel_name,
dtype_name_A,
dtype_name_B,
DataTypeNames[operation.C.element],
DataTypeNames[operation.tile_description.math_instruction.element_accumulator],
DataTypeNames[operation.element_epilogue],
DataTypeNames[operation.D.element],
    DataTypeNames[scale_factor_A_type],
    DataTypeNames[scale_factor_D_type],
layout_name_A,
layout_name_B,
layout_name_C,
layout_name_D,
str(operation.A.alignment),
str(operation.B.alignment),
str(operation.C.alignment),
str(operation.D.alignment),
numcta_inst,
"Y" if 'stream_k' in kernel_name else "N",
]
return audit_vals
# This helper function reports other performance-related kernel parameters, including those that can be specified at runtime: cluster shape, instruction shape, m/n/k, and alpha/beta.
def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster):
math_inst = operation.tile_description.math_instruction
audit_vals = [
str(math_inst.instruction_shape[0]),
str(math_inst.instruction_shape[1]),
str(math_inst.instruction_shape[2]),
str(operation.tile_description.threadblock_shape[0]),
str(operation.tile_description.threadblock_shape[1]),
str(operation.tile_description.threadblock_shape[2]),
str(operation.tile_description.cluster_shape[0]),
str(operation.tile_description.cluster_shape[1]),
str(operation.tile_description.cluster_shape[2]),
str(cluster_shape[0]),
str(cluster_shape[1]),
str(cluster_shape[2]),
str(fallback_cluster_shape[0]),
str(fallback_cluster_shape[1]),
str(fallback_cluster_shape[2]),
str(problem_shape[0]),
str(problem_shape[1]),
str(problem_shape[2]),
str(problem_shape[3]),
str(alpha),
str(beta),
"Y" if dynamic_datatype else "N",
"Y" if dynamic_cluster else "N",
]
return audit_vals
def _getSubOperationType(kernel):
if kernel.operation_kind == OperationKind.Gemm:
return GemmKindNames[kernel.gemm_kind]
elif kernel.operation_kind == OperationKind.Conv2d:
return "conv_" + ConvKindNames[kernel.conv_kind]
elif kernel.operation_kind == OperationKind.Syrk:
return "syrk_" + SyrkKindNames[kernel.syrk_kind]
elif kernel.operation_kind == OperationKind.Trmm:
return "trmm_" + TrmmKindNames[kernel.trmm_kind]
elif kernel.operation_kind == OperationKind.Symm:
return "symm_" + SymmKindNames[kernel.symm_kind]
else:
raise Exception("Unsupported kernel type")
def _get_inst_shape(math_instruction):
return "".join(str(x) for x in math_instruction.instruction_shape)
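# "111" is the 1x1x1 shape of scalar SIMT FMA instructions (FFMA/DFMA/HFMA);
# "114" is the 1x1x4 shape of the 4-way int8 dot-product instruction (IDP4A).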
def _is_simt_inst(math_instruction):
return _get_inst_shape(math_instruction) in ["111","114"]
def _getInstType(input_precision, accumulate_precision, math_instruction):
# inst_shape
inst_shape = _get_inst_shape(math_instruction)
# input precision
if input_precision == "fp32" and inst_shape != "111":
inp = "tf32"
else:
inp = input_precision
# Handle SIMT op types first
if _is_simt_inst(math_instruction):
simt_input_precision_to_inst = {
"fp32": "FFMA",
"fp64": "DFMA",
"fp16": "HFMA",
"int8": "IDP4A",
}
inst = simt_input_precision_to_inst[input_precision]
else: # Tensor op instructions
if accumulate_precision == "cf64":
fp64_acc_map = {
MathOperation.multiply_add_complex_gaussian : "gz",
MathOperation.multiply_add_complex : "z",
}
acc = fp64_acc_map[math_instruction.math_operation]
else:
tensor_op_acc_map = {
"fp32" : "s",
"cf32" : "s",
"fp16" : "h",
"int32": "i",
"fp64" : "d",
}
acc = tensor_op_acc_map[accumulate_precision]
inst = "{}{}{}".format(acc, inst_shape, inp)
return inst
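# Illustrative outputs (assuming these precisions appear in the manifest): a 16x8x8
# tensor op with fp32 inputs and fp32 accumulation yields "s1688tf32" (fp32 inputs are
# promoted to tf32 for tensor ops), while a 1x1x1 SIMT instruction with fp64 inputs
# yields "DFMA".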
# TODO: Computes FLOPs/byte for GEMM - revisit for conv
def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1):
assert not (batch_count > 1 and num_groups > 1)
# TODO: adjust for sparsity
gmem_bytes = (
(DataTypeSize[operation.A.element] * m // 8) * k +
(DataTypeSize[operation.B.element] * n // 8) * k +
(DataTypeSize[operation.C.element] * m // 8) * n
)
# TODO: complex-valued support
flops = 2 * (m * n * k)
if bool(beta):
gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n
flops += 2 * m * n
multiplier = max(batch_count, num_groups)
gmem_bytes *= multiplier
flops *= multiplier
return flops / gmem_bytes
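# Worked example (hypothetical f16 GEMM, m = n = k = 1024, beta = 0): A, B, and C are
# 2 bytes per element, so gmem_bytes = 3 * (2 * 1024 * 1024) = 6291456 and
# flops = 2 * 1024**3 = 2147483648, i.e. roughly 341 FLOPs/byte.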
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode):
  # For functional testing, we prefer to run reference computation on device when available
  reference_device_archs = ["100a", "103a"]
  run_reference_on_device = arch in reference_device_archs and mode in ["functional_L0", "functional_L1"]
  profiler_flags_for_verification = "device" if run_reference_on_device else "host"
# beta values for L0 and L1
# TODO: randomize beta values for wider coverage
beta_values = [0.5]
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
if (mode == "functional_L0") and is_supported_arch:
problem_waves = [0.5, 1.25, 2.5]
#
# Dense Gemm
#
sm100_mma_data_type_general = [
'gemm_f16_f16_f16_f16_f16',
'gemm_f16_f16_f16_void_f16',
#'gemm_f16_f16_f32_f16_f16',
'tf32gemm_f32_f32_f32_f32_f32',
'bf16gemm_f32_f32_f32_f32_f32',
]
    exclude_archs = arch not in ("103a",)
    if exclude_archs:
      sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
sm100_mma_data_type_runtime_dtype = [
'gemm.*f4_f4_f32_f32_f32',
'gemm.*f6_f6_f32_f32_f32',
'gemm.*f8_f8_f32_f32_f32',
]
sm100_mma_cluster_size = [
'8x1x1',
'4x4x1', '2x1x1',
'0x0x1' # dynamic cluster
]
# Restrict to two layouts to reduce L0 build and test time.
sm100_mma_layouts = [
'tnt',
'ntn'
]
# regex list must be in kernel procedural name order
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
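    # For reference, the 1sm filter expands to a pattern of the form (one alternation
    # per filter dimension, in kernel procedural name order):
    #   cutlass3x_sm100_tensorop.*(gemm_f16_f16_f16_f16_f16|...).*(8x1x1|4x4x1|2x1x1|0x0x1).*(tnt|ntn).*1sm.*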
#
# Block Scale Gemm
#
block_scaled_data_type = [
# runtime datatypes
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
#'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
]
block_scaled_tile_k = ['x128_', 'x256_']
sm103_block_scaled_data_type = [
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
]
sm103_block_scaled_tile_k = ['x768_']
block_scaled_cluster_size = [
'4x4x1', '2x1x1',
'0x0x1' # dynamic cluster
]
block_scaled_layouts = ['tnt']
# regex list must be in kernel procedural name order
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
sm103_block_scaled_prefetch_policy = ['tmapf']
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*"
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*"
    if arch in ["100a", "100f", "101a", "101f", "110a", "110f"]:
      kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
                      f"({sm100_mma_filter_regex_2sm})|" \
                      f"({sm100_mma_filter_regex_1sm_runtime})|" \
                      f"({sm100_mma_filter_regex_2sm_runtime})|" \
                      f"({block_scaled_filter_regex_1sm})|" \
                      f"({block_scaled_filter_regex_2sm})"
elif arch in ["103a"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})|" \
f"({sm103_block_scaled_filter_regex_1sm})|" \
f"({sm103_block_scaled_filter_regex_2sm})"
elif arch in ["120a", "120f", "121a", "121f"]:
# blockscaled sm120_mma kernels
blockscaled_sm120_mma_kernel_cta_tiles = [
[ '128x128' ]
]
# Restrict to two layouts to reduce L0 build and test time.
blockscaled_sm120_mma_layouts = [ 'tn' ]
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
problem_waves = [0.5, 1.25, 2.5]
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
    else:
      raise Exception("Unsupported arch; supported archs are sm100a, sm100f, sm101a, sm101f, sm103a, sm110a, sm110f, sm120a, sm120f, sm121a, sm121f")
elif mode == "functional_L1":
sm100_mma_cluster_size = [
'0x0x1' # dynamic cluster
]
# Restrict to two layouts to reduce L1 build and test time.
sm100_mma_layouts = ['tnt', 'ntn']
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
block_scaled_data_type = [
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2',
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
]
sm103_block_scaled_data_type = [
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
]
block_scaled_cluster_size = ['0x0x1']
block_scaled_layouts = ['tnt']
# regex list must be in kernel procedural name order
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
    filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
                             f"({sm100_mma_filter_regex_2sm})|" \
                             f"({block_scaled_filter_regex_1sm})|" \
                             f"({block_scaled_filter_regex_2sm})|" \
                             f"({sm103_block_scaled_filter_regex_1sm})|" \
                             f"({sm103_block_scaled_filter_regex_2sm})"
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
sm120_mma_kernel_cta_tiles = [
# h1688, s1688, i16832, i8816
[ '256x128' ],
# d884, c1688,
[ '128x128' ],
# c1688, z884
[ '128x64' ],
# gz884
[ '64x64' ]
]
    # sm120 MMA instruction shapes; planar complex types are excluded as they are not required
sm120_mma_instruction_shapes = [
[ 'h1688gemm_(?!planar_complex)',
's1688gemm_f16',
's1688gemm_bf16',
's1688gemm_tf32',
'i16832gemm',
'i8816gemm' ],
      [ 'd884gemm', 'c1688tf32gemm' ],
[ 'c1688gemm',
'z884gemm' ],
[ 'gz884gemm']
]
    # It's not pretty, but it's not clear why different instructions support different tile sizes.
filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*"
filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*"
filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*"
filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*"
filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})"
problem_waves = [0.5, 1.25, 2.5]
if arch in ["120a", "120f", "121a", "121f"]:
kernel_filter = f"({filter_regex_sm120_mma})"
else:
kernel_filter = f"({filter_regex_sm100_mma})"
  else:
    raise ValueError(f"Unsupported mode: {mode}")
outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv")
audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv")
kernel_filter_re = re.compile(kernel_filter)
testcase_counter = 0
kernels_emitted = 0
kernels_total = 0
perf_json_list = []
kernel_name_set = set()
testlist_csv_fields = ["testcase", "metadata"]
testlist_csv_rows = []
auditlist_csv_map = {}
auditlist_csv_params_map = {}
kernel_features = {}
for cc in manifest.operations[OperationKind.Gemm].keys():
for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items():
assert(len(operation_l) == 1)
kernels_total += 1
if len(kernel_filter_re.findall(kernel_name)) == 0:
continue
# Only test f16 I/O void C kernels in void C kernel set
# Exception: Use void C kernels for more accurate perf testing
if '_void_' in kernel_name and 'perf_' not in mode:
        if 'f16_f16_f16_void_f16' not in kernel_name:
continue
kernels_emitted += 1
kernel_name_set.add(kernel_name)
hashed_kernel_name = hash_cutlass_string(kernel_name)
operation = operation_l[0]
dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0
or operation.tile_description.cluster_shape[1] == 0)
dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name
runtime_input_datatypes = [None]
if dynamic_datatype:
if "f4_f4" in kernel_name:
runtime_input_datatypes = [['e2m1','e2m1']]
elif "f4_f6" in kernel_name:
runtime_input_datatypes = [['e2m1','e3m2']]
elif "f4_f8" in kernel_name:
runtime_input_datatypes = [['e2m1','e4m3']]
elif "f6_f4" in kernel_name:
runtime_input_datatypes = [['e3m2','e2m1']]
elif "f6_f6" in kernel_name:
runtime_input_datatypes = [['e3m2','e3m2']]
elif "f6_f8" in kernel_name:
runtime_input_datatypes = [['e3m2','e4m3']]
elif "f8_f4" in kernel_name:
runtime_input_datatypes = [['e4m3','e2m1']]
elif "f8_f6" in kernel_name:
runtime_input_datatypes = [['e4m3','e3m2']]
elif "f8_f8" in kernel_name:
runtime_input_datatypes = [
# mask out those not covered in statically encoded test cases
# ['e5m2','e4m3'],
# ['e4m3','e5m2'],
['e4m3','e4m3']
]
# block scaled kernels
elif "ue8m0xf4_ue8m0xf4" in kernel_name:
runtime_input_datatypes = [['e2m1','e2m1']]
elif "ue4m3xf4_ue4m3xf4" in kernel_name:
runtime_input_datatypes = [['e2m1','e2m1']]
elif "ue8m0xf4_ue8m0xf6" in kernel_name:
runtime_input_datatypes = [['e2m1','e2m3']]
elif "ue8m0xf4_ue8m0xf8" in kernel_name:
runtime_input_datatypes = [['e2m1','e4m3']]
elif "ue8m0xf6_ue8m0xf4" in kernel_name:
runtime_input_datatypes = [['e2m3','e2m1']]
elif "ue8m0xf6_ue8m0xf6" in kernel_name:
runtime_input_datatypes = [['e2m3','e2m3']]
        elif "ue8m0xf8_ue8m0xf4" in kernel_name:
          runtime_input_datatypes = [['e4m3','e2m1']]
elif "ue8m0xf8_ue8m0xf6" in kernel_name:
runtime_input_datatypes = [['e4m3','e2m3']]
elif "ue8m0xf8_ue8m0xf8" in kernel_name:
runtime_input_datatypes = [['e4m3','e4m3']]
if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
profiler_flags_for_verification = "host"
      # Reduce L1 test runtime if the reference kernel is not running on device.
      if mode == "functional_L1" and profiler_flags_for_verification == "host":
        problem_waves = [0.5, 2.5]
if dynamic_cluster:
if mode == "functional_L0":
runtime_cluster_shapes = [[1,1,1], [2,2,1]]
else:
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]]
# reduce L1 test runtime if reference kernel is not running on device.
if profiler_flags_for_verification == "host":
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]]
cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape
else:
runtime_cluster_shapes = [operation.tile_description.cluster_shape]
        cta_tile_shape_m = operation.tile_description.threadblock_shape[0] // operation.tile_description.cluster_shape[0]
        cta_tile_shape_n = operation.tile_description.threadblock_shape[1] // operation.tile_description.cluster_shape[1]
        cta_tile_shape_k = operation.tile_description.threadblock_shape[2] // operation.tile_description.cluster_shape[2]
alignment_a = operation.A.alignment
alignment_b = operation.B.alignment
alignment_c = operation.C.alignment
alignment_ab_max = max(alignment_a, alignment_b)
layout3x = operation.layout_name_3x()
data_types = operation.datatype_name_3x()
      ctas_per_mma_instruction = 1
      if '_2sm' in kernel_name:
        ctas_per_mma_instruction = 2
        # 2SM kernels require cluster_m to be divisible by 2; remove any cluster shapes that are not
        valid_cluster_shapes = []
        for cs in runtime_cluster_shapes:
          if cs[0] % 2 == 0:
            valid_cluster_shapes.append(cs)
        runtime_cluster_shapes = valid_cluster_shapes
kernel_problem_waves = problem_waves
      if mode == "functional_L0" or mode == "functional_L1":
        # For functional testing, we want to perturb just a little from even shapes.
        # The large-K multiplier of 8 is chosen such that some kernels will wrap around their smem buffers, and some will not.
        # Subtracting the max A/B alignment keeps the problem TMA aligned even for FP8/Int8.
        min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max
        max_k = (cta_tile_shape_k * 8) - alignment_ab_max
problem_shapes_k = [min_k, max_k]
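        # e.g., with hypothetical values cta_tile_shape_k = 64 and alignment_ab_max = 16,
        # this tests k = 48 (one slightly underfull tile) and k = 496 (eight tiles minus one alignment unit)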
sm_count = 16
swizzle_sizes = [0]
        # Larger K and fewer than half a wave trigger the stream-K + separate-reduction path
        if 'stream_k' in kernel_name:
          problem_shapes_k = [max_k, cta_tile_shape_k * 32]
          kernel_problem_waves = [0.125, 1.25, 2.5]
      else:
        raise ValueError(f"Unsupported mode: {mode}")
if "void" in kernel_name:
beta_values = [0]
alignment_shift_m = max(alignment_c, alignment_a)
alignment_shift_n = max(alignment_c, alignment_b)
is_first_line = True
for index_waves, waves in enumerate(kernel_problem_waves):
for index_k, k in enumerate(problem_shapes_k):
for beta in beta_values:
for cluster_shape in runtime_cluster_shapes:
for runtime_input_datatype in runtime_input_datatypes:
for swizzle_size in swizzle_sizes:
grid_size = waves * sm_count
cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape)
if cluster_shape_m >= cluster_shape_n:
grid_m = cluster_shape_m
grid_n = grid_size / grid_m
grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1)
else:
grid_n = cluster_shape_n
grid_m = grid_size / grid_n
grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1)
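                  # e.g., waves = 0.5 and a 4x4x1 cluster give grid_size = 8, grid_m = 4, and
                  # grid_n = 8 / 4 = 2, rounded up to the next cluster multiple -> a 4x4 CTA grid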
verification_required = False
if mode == "functional_L0" or mode == "functional_L1":
if '_void_' not in kernel_name:
verification_required = True
m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max)
n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max)
k = int(k)
# For functional testing, we want to perturb just a little from even shapes.
# Only do this if the perturbation does not cause one of the dimensions of the
# problem size to go to zero. This can occur for blockscaling kernels for which
# the alignment requirements for A and B can be quite large (e.g., 256).
if m > alignment_shift_m:
m -= alignment_shift_m
if n > alignment_shift_n:
n -= alignment_shift_n
if '_n32t32_' in kernel_name:
continue
batch_count = 1
                  if mode == "functional_L0" or mode == "functional_L1":
                    if index_waves == 0 and index_k == 0:
                      batch_count = 3 if mode == "functional_L0" else 5
gemm_op = "gemm"
grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind)
num_groups = 1
if grouped:
gemm_op = "grouped_gemm"
num_groups = 3 # small to limit test time in host block-scaled reference kernels
batch_count = 1
elif "bstensorop" in kernel_name:
gemm_op = "block_scaled_gemm"
elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
gemm_op = "blockwise_gemm"
problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)]
assert m > 0 and n > 0 and k > 0
# Emit per-testcase metadata for perf testing usage, eventually in perf database
metadata_dict = {
"input_params": {
'problem_size_category' : problem_size_category,
'operation' : _getSubOperationType(operation),
'datatype' : data_types,
'layout' : layout3x,
'm' : m,
'n' : n,
'k' : k,
'beta' : beta,
'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups)
},
"runtime_params": {
'ctas_per_mma_instruction' : ctas_per_mma_instruction,
'tilesize_m' : cta_tile_shape_m,
'tilesize_n' : cta_tile_shape_n,
'tilesize_k' : cta_tile_shape_k,
'cluster_shape_m' : cluster_shape_m,
'cluster_shape_n' : cluster_shape_n,
}
}
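                  # For dynamic-cluster kernels the fallback cluster defaults to
                  # ctas_per_mma_instruction x 1 x 1; for static-cluster kernels it
                  # mirrors the kernel's own cluster shape.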
cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m
cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n
cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k
if dynamic_datatype:
runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype)
metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a
metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b
testcase_metadata = [
f"cutlass_profiler --operation={gemm_op}" +
(f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") +
f" --error-on-no-match --error-if-nothing-is-profiled" +
f" --kernels={kernel_name}" +
f" --m={str(m)}" +
f" --n={str(n)}" +
f" --k={str(k)}" +
(f" --num_groups={str(num_groups)}" if grouped else "") +
f" --cluster_m={str(cluster_shape_m)}" +
f" --cluster_n={str(cluster_shape_n)}" +
f" --cluster_k={str(cluster_shape_k)}" +
f" --cluster_m_fallback={str(cluster_m_fallback)}" +
f" --cluster_n_fallback={str(cluster_n_fallback)}" +
f" --cluster_k_fallback={str(cluster_k_fallback)}" +
f" --beta={str(beta)}" +
("" if grouped else f" --batch_count={str(batch_count)}") +
f" --swizzle_size={str(swizzle_size)}" +
f" --verification-required={str(verification_required).lower()}"
                  ]
                  output_dynamic_datatype = dynamic_datatype
if output_dynamic_datatype:
testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" +
f" --runtime_input_datatype_b={runtime_datatype_b}")
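                  # An emitted profiler command resembles (illustrative values only):
                  #   cutlass_profiler --operation=gemm --mode=trace --error-on-no-match --error-if-nothing-is-profiled
                  #     --kernels=<kernel_name> --m=496 --n=240 --k=48 --cluster_m=2 --cluster_n=1 --cluster_k=1
                  #     --cluster_m_fallback=2 --cluster_n_fallback=1 --cluster_k_fallback=1 --beta=0.5
                  #     --batch_count=1 --swizzle_size=0 --verification-required=true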
testcase_metadata.append(json.dumps(metadata_dict))
testlist_csv_rows.append(testcase_metadata)
testcase_counter += 1
alpha = 1.0
if dynamic_datatype:
hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b)
# If kernel_name is new, initialize its feature set with defaults
if hashed_kernel_name not in kernel_features:
kernel_features[hashed_kernel_name] = {
"is_support_dynamic_cluster": False,
"is_support_dynamic_datatype": False,
}
# Update features for the hashed kernel name
kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster
kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype
if hashed_kernel_name not in auditlist_csv_params_map:
auditlist_csv_params_map[hashed_kernel_name] = []
audit_row_params = get_kernel_params(
operation,
hashed_kernel_name,
(cluster_shape_m, cluster_shape_n, cluster_shape_k),
(cluster_m_fallback, cluster_n_fallback, cluster_k_fallback),
(m, n, k, batch_count),
alpha, beta,
dynamic_datatype, dynamic_cluster
)
auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params)
if hashed_kernel_name not in auditlist_csv_map:
audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype)
auditlist_csv_map[hashed_kernel_name] = audit_row
with open(outfile_name, 'w') as testlist_csv:
csv_writer = csv.writer(testlist_csv, delimiter=',')
csv_writer.writerow(testlist_csv_fields)
csv_writer.writerows(testlist_csv_rows)
with open(audit_file_name, 'w') as auditlist_csv:
csv_writer = csv.writer(auditlist_csv, delimiter=',')
csv_writer.writerow(audit_csv_fields)
for hashed_kernel_name, row in auditlist_csv_map.items():
# Append the dynamic features as "Y" or "N"
dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N"
dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N"
test_count = len(auditlist_csv_params_map[hashed_kernel_name])
csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count])
with open(audit_file_params_name, 'w') as auditlist_csv:
csv_writer = csv.writer(auditlist_csv, delimiter=',')
csv_writer.writerow(audit_csv_runtime_fields)
for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1):
for i, row in enumerate(rows):
if i == 0:
csv_writer.writerow([kernel_index, hashed_kernel_name] + row)
else:
csv_writer.writerow(["", ""] + row)
print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.")
# Generate a newline separated list of kernel filters
assert(len(kernel_name_set) == kernels_emitted)
output_filter_enabled = True
if output_filter_enabled:
kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
with open(kernel_filter_outfile_name, "w") as file:
kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set))
for kernel_name in kernel_name_set:
file.write(kernel_name + "\n")
  # Sort the L0/L1 kernel list and csv file so that cutlass3.x kernels and cutlass2.x-generated sm120_mma kernels are not interleaved.
  if mode == "functional_L0" or mode == "functional_L1":
    # Sort the .csv file
    with open(outfile_name) as file:
      data = file.readlines()
    data.sort()
    with open(outfile_name, 'w') as file:
      file.writelines(data)
    # Sort the kernel list
    kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
    with open(kernel_filter_outfile_name) as file:
      data = file.readlines()
    data.sort()
    with open(kernel_filter_outfile_name, 'w') as file:
      file.writelines(data)