v3.8.0 update (#2082)

* 3.8 update

* fix Markus' name

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-02-06 18:33:40 -08:00
committed by GitHub
parent affd1b693d
commit 833f6990e0
168 changed files with 24945 additions and 3436 deletions

View File

@ -0,0 +1,834 @@
#################################################################################################
#
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
#
#
# \brief Generates the CUTLASS kernel listing with kernel filtering
#
#
###############################################################################
# Example usage:
# generator.py --operations all --generator-target kernel_listing \
# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports
###############################################################################
import collections
import csv
import json
import math
import os
import re
try:
import builtins
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
raise ImportError("Disabling attempt to import cutlass_library")
from cutlass_library.library import *
except ImportError:
from library import *
# Column headers for the per-kernel audit CSV (one row per hashed kernel name).
audit_csv_fields = [
    "KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD",
    "Layout_A", "Layout_B", "Layout_C", "Layout_D",
    "Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D",
    "1SM/2SM",
    "StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types",
    "Test Counts"
]
# Column headers for the per-testcase runtime-parameter audit CSV.
# Fix: "KerneIndex" -> "KernelIndex" (typo in the emitted CSV header).
audit_csv_runtime_fields = [
    "KernelIndex", "KernelName",
    "Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K",
    "Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K",
    "M", "N", "K", "L", "Alpha_val", "Beta_val",
    "Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled"
]
def hash_cutlass_string(input_string):
    """Strip shape tokens from a CUTLASS kernel name.

    Removes the instruction shape (e.g. 's128x128x64', 'h64x128x16') and the
    MMA/cluster shapes (e.g. '_128x128x256', '_0x0x1') so that kernels that
    differ only in shape configuration hash to the same key.
    """
    # A letter immediately followed by MxNxK digits marks the instruction shape.
    without_inst_shape = re.sub(r"[a-zA-Z]\d+x\d+x\d+", "", input_string)
    # An underscore followed by MxNxK digits marks MMA / cluster shapes.
    return re.sub(r"_\d+x\d+x\d+", "", without_inst_shape)
def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
    """Substitute generic runtime-datatype tokens with the concrete types.

    '_f4_' and '_f8_' take the runtime A datatype; '_f6_' takes the runtime
    B datatype. Any other token is left untouched.
    """
    replacements = {
        '_f4_': '_' + runtime_datatype_a + '_',
        '_f6_': '_' + runtime_datatype_b + '_',
        '_f8_': '_' + runtime_datatype_a + '_',
    }
    # Replace each placeholder occurrence; fall back to the match itself.
    return re.sub(
        r'_f4_|_f6_|_f8_',
        lambda match: replacements.get(match.group(0), match.group(0)),
        hashed_kernel_name,
    )
# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k.
def get_kernel_features(operation, kernel_name,
                        dynamic_datatype, runtime_input_datatype):
    """Build the per-kernel feature row matching ``audit_csv_fields`` order.

    Returns the first 20 columns (type/layout/alignment/1SM-2SM/stream-k);
    the dynamic-feature flags and test count are appended by the caller.
    """
    numcta_inst = "2sm" if "2sm" in kernel_name else "1sm"
    math_inst = operation.tile_description.math_instruction
    if dynamic_datatype:
        # Runtime-datatype kernels report the concrete types chosen at runtime.
        dtype_name_A = runtime_input_datatype[0]
        dtype_name_B = runtime_input_datatype[1]
    else:
        dtype_name_A = DataTypeNames[operation.A.element]
        dtype_name_B = DataTypeNames[operation.B.element]
    layout_name_A = ShortLayoutTypeNames[operation.A.layout]
    layout_name_B = ShortLayoutTypeNames[operation.B.layout]
    layout_name_C = ShortLayoutTypeNames[operation.C.layout]
    layout_name_D = ShortLayoutTypeNames[operation.D.layout]
    scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void
    scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void)
    audit_vals = [
        "BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM",
        kernel_name,
        dtype_name_A,
        dtype_name_B,
        DataTypeNames[operation.C.element],
        DataTypeNames[operation.tile_description.math_instruction.element_accumulator],
        DataTypeNames[operation.element_epilogue],
        DataTypeNames[operation.D.element],
        # Fix: emit SFA before SFD to match the audit_csv_fields column order
        # ("Type_SFA", "Type_SFD"); the values were previously swapped.
        DataTypeNames[scale_factor_A_type],
        DataTypeNames[scale_factor_D_type],
        layout_name_A,
        layout_name_B,
        layout_name_C,
        layout_name_D,
        str(operation.A.alignment),
        str(operation.B.alignment),
        str(operation.C.alignment),
        str(operation.D.alignment),
        numcta_inst,
        "Y" if 'stream_k' in kernel_name else "N",
    ]
    return audit_vals
# This helper function reports other performance-related kernel parameters and those that can be
# specified at runtime: cluster shape, instruction shape, m/n/k and alpha/beta.
def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster):
    """Build the runtime-parameter row matching ``audit_csv_runtime_fields``.

    All numeric values are stringified; the two trailing columns are Y/N
    flags for runtime input types and runtime cluster shape support.
    """
    tile_desc = operation.tile_description
    numeric_fields = (
        list(tile_desc.math_instruction.instruction_shape[:3])
        + list(tile_desc.threadblock_shape[:3])
        + list(tile_desc.cluster_shape[:3])
        + list(cluster_shape[:3])
        + list(fallback_cluster_shape[:3])
        + list(problem_shape[:4])
        + [alpha, beta]
    )
    return [str(value) for value in numeric_fields] + [
        "Y" if dynamic_datatype else "N",
        "Y" if dynamic_cluster else "N",
    ]
def _getSubOperationType(kernel):
    """Return a short sub-operation label (e.g. the gemm kind, 'conv_<kind>')."""
    kind = kernel.operation_kind
    if kind == OperationKind.Gemm:
        return GemmKindNames[kernel.gemm_kind]
    if kind == OperationKind.Conv2d:
        return "conv_" + ConvKindNames[kernel.conv_kind]
    if kind == OperationKind.Syrk:
        return "syrk_" + SyrkKindNames[kernel.syrk_kind]
    if kind == OperationKind.Trmm:
        return "trmm_" + TrmmKindNames[kernel.trmm_kind]
    if kind == OperationKind.Symm:
        return "symm_" + SymmKindNames[kernel.symm_kind]
    raise Exception("Unsupported kernel type")
def _get_inst_shape(math_instruction):
return "".join(str(x) for x in math_instruction.instruction_shape)
def _is_simt_inst(math_instruction):
return _get_inst_shape(math_instruction) in ["111","114"]
def _getInstType(input_precision, accumulate_precision, math_instruction):
# inst_shape
inst_shape = _get_inst_shape(math_instruction)
# input precision
if input_precision == "fp32" and inst_shape != "111":
inp = "tf32"
else:
inp = input_precision
# Handle SIMT op types first
if _is_simt_inst(math_instruction):
simt_input_precision_to_inst = {
"fp32": "FFMA",
"fp64": "DFMA",
"fp16": "HFMA",
"int8": "IDP4A",
}
inst = simt_input_precision_to_inst[input_precision]
else: # Tensor op instructions
if accumulate_precision == "cf64":
fp64_acc_map = {
MathOperation.multiply_add_complex_gaussian : "gz",
MathOperation.multiply_add_complex : "z",
}
acc = fp64_acc_map[math_instruction.math_operation]
else:
tensor_op_acc_map = {
"fp32" : "s",
"cf32" : "s",
"fp16" : "h",
"int32": "i",
"fp64" : "d",
}
acc = tensor_op_acc_map[accumulate_precision]
inst = "{}{}{}".format(acc, inst_shape, inp)
return inst
# TODO: Computes FLOps/Bytes for GEMM - revisit for conv
def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0):
    """Arithmetic intensity (FLOPs per global-memory byte) of one GEMM call."""
    def _operand_bytes(element, rows, cols):
        # Bytes of one rows x cols operand; '// 8' converts element bits to bytes.
        return (DataTypeSize[element] * rows // 8) * cols
    # TODO: adjust for sparsity
    gmem_bytes = (
        _operand_bytes(operation.A.element, m, k)
        + _operand_bytes(operation.B.element, n, k)
        + _operand_bytes(operation.C.element, m, n)
    )
    # TODO: complex-valued support
    flops = 2 * (m * n * k)
    if bool(beta):
        # A non-zero beta reads C an extra time and adds the epilogue FMA work.
        gmem_bytes += _operand_bytes(operation.C.element, m, n)
        flops += 2 * m * n
    gmem_bytes *= batch_count
    flops *= batch_count
    return flops / gmem_bytes
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
                              ):
    """Emit the functional GEMM test list plus audit CSVs for one arch/mode.

    Args:
        manifest: kernel manifest; only manifest.operations[OperationKind.Gemm]
            is consumed.
        curr_build_dir: directory receiving the generated .csv/.list files.
        arch: compute capability suffix; only "100a" is supported.
        mode: "functional_L0" or "functional_L1"; anything else raises.

    Outputs (written under curr_build_dir):
        FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv        profiler commands + metadata
        FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv           per-kernel feature audit
        FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv    per-testcase runtime params
        FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list

    Review fixes relative to the previous revision:
      * removed a stray line-continuation backslash after the testcase_metadata
        list (syntax error),
      * dropped the trailing '|' from filter_regex_sm100_mma, which created an
        empty regex alternative matching every kernel name in L1 mode,
      * removed a duplicated, unreachable "ue8m0xf8_ue8m0xf4" elif branch,
      * the beta override for void-C kernels no longer mutates the shared
        beta_values list (it used to stick for every subsequent kernel),
      * removed unused locals (is_first_line, perf_json_list and the
        output_dynamic_datatype alias).
    """
    # Builds "{prefix}.*(a|b).*(c|d).*{suffix}" from token groups. The groups
    # must be listed in kernel procedural-name order.
    def _name_filter(prefix, token_groups, suffix):
        body = ").*(".join("|".join(group) for group in token_groups)
        return f"{prefix}.*({body}).*{suffix}"

    profiler_reference_computing = "--verification-providers=device --providers=cutlass"
    # beta values for L0 and L1
    # TODO: randomize beta values for wider coverage
    beta_values = [0.5]
    is_supported_arch = (arch in ["100a"])
    is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
    if (mode == "functional_L0") and is_supported_arch:
        problem_waves = [0.5, 1.25, 2.5]
        #
        # Dense Gemm
        #
        sm100_mma_data_type_general = [
            'x16gemm_f16_f16_f16_f16_f16',
            'x16gemm_f16_f16_f16_void_f16',
            'x16gemm_f16_f16_f32_f16_f16',
            'x8tf32gemm_f32_f32_f32_f32_f32',
            'x16bf16gemm_f32_f32_f32_f32_f32',
        ]
        sm100_mma_data_type_runtime_dtype = [
            'x32gemm_f4_f4_f32_f32_f32',
            'x32gemm_f6_f6_f32_f32_f32',
            'x32gemm_f8_f8_f32_f32_f32',
        ]
        sm100_mma_data_type_mergeable = [
            'x32gemm_e4m3_e4m3_f32_f32_f32',  # mask out one instance for verification
            'x32gemm_e2m1_e2m1_f32_f32_f32',
            'x32gemm_e3m2_e3m2_f32_f32_f32',
        ]
        sm100_mma_cluster_size = [
            '8x1x1',
            '4x4x1', '2x1x1',
            '0x0x1'  # dynamic cluster
        ]
        # Restrict to two layouts to reduce L0 build and test time.
        sm100_mma_layouts = [
            'tnt',
            'ntn'
        ]
        sm100_mma_instruction_shape = [
            # [0] .1CTA, General
            ['64x128', '128x128', '128x256'],
            # [1] .2CTA, General
            ['128x128', '256x128', '256x256'],
        ]
        # regex list must be in kernel procedural name order
        mergeable_sm100_mma_filter_regex_1sm = _name_filter(
            "cutlass3x_sm100_tensorop",
            [sm100_mma_instruction_shape[0], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts],
            "1sm.*")
        mergeable_sm100_mma_filter_regex_2sm = _name_filter(
            "cutlass3x_sm100_tensorop",
            [sm100_mma_instruction_shape[1], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts],
            "2sm.*")
        sm100_mma_filter_regex_1sm = _name_filter(
            "cutlass3x_sm100_tensorop",
            [sm100_mma_instruction_shape[0], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts],
            "1sm.*")
        sm100_mma_filter_regex_2sm = _name_filter(
            "cutlass3x_sm100_tensorop",
            [sm100_mma_instruction_shape[1], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts],
            "2sm.*")
        sm100_mma_filter_regex_1sm_runtime = _name_filter(
            "cutlass3x_sm100_tensorop",
            [sm100_mma_instruction_shape[0], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts],
            "1sm.*")
        sm100_mma_filter_regex_2sm_runtime = _name_filter(
            "cutlass3x_sm100_tensorop",
            [sm100_mma_instruction_shape[1], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts],
            "2sm.*")
        #
        # Block Scale Gemm
        #
        block_scaled_data_type_base = [
            # runtime datatypes
            'x32gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
            'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
            'x32gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
            'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
            'x32gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
        ]
        block_scaled_data_type_mergeable = [
            'x32gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
            'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
            'x32gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
            'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
            'x32gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
        ]
        block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable
        block_scaled_cluster_size = [
            '4x4x1', '2x1x1',
            '0x0x1'  # dynamic cluster
        ]
        block_scaled_layouts = ['tnt']
        block_scaled_instruction_shape = [
            # .1CTA
            ['128x128', '128x192', '128x256'],
            # .2CTA
            ['256x128', '256x192', '256x256'],
        ]
        # regex list must be in kernel procedural name order
        mergeable_block_scaled_filter_regex_1sm = _name_filter(
            "cutlass3x_sm100_bstensorop",
            [block_scaled_instruction_shape[0], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts],
            "1sm.*")
        mergeable_block_scaled_filter_regex_2sm = _name_filter(
            "cutlass3x_sm100_bstensorop",
            [block_scaled_instruction_shape[1], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts],
            "2sm.*")
        block_scaled_filter_regex_1sm = _name_filter(
            "cutlass3x_sm100_bstensorop",
            [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts],
            "1sm.*")
        block_scaled_filter_regex_2sm = _name_filter(
            "cutlass3x_sm100_bstensorop",
            [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts],
            "2sm.*")
        if arch == "100a":
            kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
                            f"({sm100_mma_filter_regex_2sm})|" \
                            f"({sm100_mma_filter_regex_1sm_runtime})|" \
                            f"({sm100_mma_filter_regex_2sm_runtime})|" \
                            f"({block_scaled_filter_regex_1sm})|" \
                            f"({block_scaled_filter_regex_2sm})"
        else:
            # Defensive: unreachable while the enclosing branch requires "100a".
            error_message = "unsupported arch, only support sm100a"
            raise Exception(error_message)
        # Statically encoded kernels are still added to generated_kernels
        # but are filtered out from the testing commands to reduce test duration.
        # The mergeable_kernel_filter specifies the kernels that are already covered
        # by the runtime datatype tests so that we safely mark them off
        # without changing the test coverage.
        mergeable_kernel_filter = f"({mergeable_sm100_mma_filter_regex_1sm})|" \
                                  f"({mergeable_sm100_mma_filter_regex_2sm})|" \
                                  f"({mergeable_block_scaled_filter_regex_1sm})|" \
                                  f"({mergeable_block_scaled_filter_regex_2sm})"
    elif mode == "functional_L1":
        sm100_mma_cluster_size = [
            '0x0x1'  # dynamic cluster
        ]
        # Restrict to two layouts to reduce L1 build and test time.
        sm100_mma_layouts = ['tnt', 'ntn']
        sm100_mma_instruction_shape = [
            # .1CTA
            ['64x128', '128x128', '128x256'],
            # .2CTA
            ['128x128', '256x128', '256x256']
        ]
        sm100_mma_filter_regex_1sm = _name_filter(
            "cutlass3x_sm100_tensorop",
            [sm100_mma_instruction_shape[0], sm100_mma_cluster_size, sm100_mma_layouts],
            "1sm.*")
        sm100_mma_filter_regex_2sm = _name_filter(
            "cutlass3x_sm100_tensorop",
            [sm100_mma_instruction_shape[1], sm100_mma_cluster_size, sm100_mma_layouts],
            "2sm.*")
        block_scaled_data_type = [
            'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
            'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
            'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2',
            'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
            'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
        ]
        block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
        block_scaled_layouts = ['tnt']
        block_scaled_instruction_shape = [
            # .1CTA
            ['128x128', '128x192', '128x256'],
            # .2CTA
            ['256x128', '256x192', '256x256'],
        ]
        # regex list must be in kernel procedural name order
        block_scaled_filter_regex_1sm = _name_filter(
            "cutlass3x_sm100_bstensorop",
            [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts],
            "1sm.*")
        block_scaled_filter_regex_2sm = _name_filter(
            "cutlass3x_sm100_bstensorop",
            [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts],
            "2sm.*")
        # Fix: no trailing '|' -- it created an empty alternative that matched
        # every kernel name and effectively disabled this filter.
        filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
                                 f"({sm100_mma_filter_regex_2sm})|" \
                                 f"({block_scaled_filter_regex_1sm})|" \
                                 f"({block_scaled_filter_regex_2sm})"
        # CTA tiles for super MMA - only run one tile size to reduce build/test times
        supermma_kernel_cta_tiles = [
            # h1688, s1688, i16832, i8816
            ['256x128'],
            # d884, c1688,
            ['128x128'],
            # c1688, z884
            ['128x64'],
            # gz884
            ['64x64']
        ]
        # super MMA instruction shapes, planar complex type excluded as they are not required
        supermma_instruction_shapes = [
            ['h1688gemm_(?!planar_complex)',
             's1688gemm_f16',
             's1688gemm_bf16',
             's1688gemm_tf32',
             'i16832gemm',
             'i8816gemm'],
            ['d884gemm', 'c1688tf32gemm'],
            ['c1688gemm',
             'z884gemm'],
            ['gz884gemm']
        ]
        # It's not pretty, but not sure why different instructions support different tile sizes.
        filter_regex_supermma_0 = _name_filter("cutlass_tensorop", [supermma_instruction_shapes[0], supermma_kernel_cta_tiles[0]], "")
        filter_regex_supermma_1 = _name_filter("cutlass_tensorop", [supermma_instruction_shapes[1], supermma_kernel_cta_tiles[1]], "")
        filter_regex_supermma_2 = _name_filter("cutlass_tensorop", [supermma_instruction_shapes[2], supermma_kernel_cta_tiles[2]], "")
        filter_regex_supermma_3 = _name_filter("cutlass_tensorop", [supermma_instruction_shapes[3], supermma_kernel_cta_tiles[3]], "")
        filter_regex_supermma = f"({filter_regex_supermma_0})|({filter_regex_supermma_1})|({filter_regex_supermma_2})|({filter_regex_supermma_3})"
        problem_waves = [0.5, 1.25, 2.5]
        kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_supermma})"
    else:
        raise ValueError()
    outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
    audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv")
    audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv")
    if is_runtime_datatype_enabled:
        mergeable_kernel_filter_re = re.compile(mergeable_kernel_filter)
    kernel_filter_re = re.compile(kernel_filter)
    testcase_counter = 0
    kernels_emitted = 0
    kernels_total = 0
    kernel_name_set = set()
    testlist_csv_fields = ["testcase", "metadata"]
    testlist_csv_rows = []
    auditlist_csv_map = {}
    auditlist_csv_params_map = {}
    kernel_features = {}
    for cc in manifest.operations[OperationKind.Gemm].keys():
        for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items():
            assert(len(operation_l) == 1)
            kernels_total += 1
            if len(kernel_filter_re.findall(kernel_name)) == 0:
                continue
            # Only test f16 I/O void C kernels in void C kernel set
            # Exception: Use void C kernels for more accurate perf testing
            if '_void_' in kernel_name and 'perf_' not in mode:
                if 'f16_f16_f16_void_f16' not in kernel_name:
                    continue
            # Filter out the statically encoded tests which are
            # covered by runtime datatype tests to avoid repetition.
            if is_runtime_datatype_enabled and len(mergeable_kernel_filter_re.findall(kernel_name)) != 0:
                continue
            kernels_emitted += 1
            kernel_name_set.add(kernel_name)
            hashed_kernel_name = hash_cutlass_string(kernel_name)
            operation = operation_l[0]
            dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0
                               or operation.tile_description.cluster_shape[1] == 0)
            dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name
            runtime_input_datatypes = [None]
            if dynamic_datatype:
                # Map the generic placeholder pair in the kernel name to the
                # concrete runtime (A, B) datatype combinations to test.
                if "f4_f4" in kernel_name:
                    runtime_input_datatypes = [['e2m1', 'e2m1']]
                elif "f4_f6" in kernel_name:
                    runtime_input_datatypes = [['e2m1', 'e3m2']]
                elif "f4_f8" in kernel_name:
                    runtime_input_datatypes = [['e2m1', 'e4m3']]
                elif "f6_f4" in kernel_name:
                    runtime_input_datatypes = [['e3m2', 'e2m1']]
                elif "f6_f6" in kernel_name:
                    runtime_input_datatypes = [['e3m2', 'e3m2']]
                elif "f6_f8" in kernel_name:
                    runtime_input_datatypes = [['e3m2', 'e4m3']]
                elif "f8_f4" in kernel_name:
                    runtime_input_datatypes = [['e4m3', 'e2m1']]
                elif "f8_f6" in kernel_name:
                    runtime_input_datatypes = [['e4m3', 'e3m2']]
                elif "f8_f8" in kernel_name:
                    runtime_input_datatypes = [
                        # mask out those not covered in statically encoded test cases
                        # ['e5m2','e4m3'],
                        # ['e4m3','e5m2'],
                        ['e4m3', 'e4m3']
                    ]
                # block scaled kernels
                elif "ue8m0xf4_ue8m0xf4" in kernel_name:
                    runtime_input_datatypes = [['e2m1', 'e2m1']]
                elif "ue4m3xf4_ue4m3xf4" in kernel_name:
                    runtime_input_datatypes = [['e2m1', 'e2m1']]
                elif "ue8m0xf4_ue8m0xf6" in kernel_name:
                    runtime_input_datatypes = [['e2m1', 'e2m3']]
                elif "ue8m0xf4_ue8m0xf8" in kernel_name:
                    runtime_input_datatypes = [['e2m1', 'e4m3']]
                elif "ue8m0xf6_ue8m0xf4" in kernel_name:
                    runtime_input_datatypes = [['e2m3', 'e2m1']]
                elif "ue8m0xf6_ue8m0xf6" in kernel_name:
                    runtime_input_datatypes = [['e2m3', 'e2m3']]
                # Fix: a second identical "ue8m0xf8_ue8m0xf4" branch was removed here
                # (it was unreachable dead code).
                elif "ue8m0xf8_ue8m0xf4" in kernel_name:
                    runtime_input_datatypes = [['e4m3', 'e2m1']]
                elif "ue8m0xf8_ue8m0xf6" in kernel_name:
                    runtime_input_datatypes = [['e4m3', 'e2m3']]
                elif "ue8m0xf8_ue8m0xf8" in kernel_name:
                    runtime_input_datatypes = [['e4m3', 'e4m3']]
            if dynamic_cluster:
                # Dynamic-cluster kernels sweep runtime cluster shapes; the CTA
                # tile equals the threadblock shape directly.
                if mode == "functional_L0":
                    runtime_cluster_shapes = [[1, 1, 1], [2, 1, 1], [2, 2, 1], [4, 1, 1], [4, 4, 1]]
                else:
                    runtime_cluster_shapes = [[1, 1, 1], [1, 2, 1], [2, 1, 1], [2, 2, 1], [1, 4, 1], [4, 1, 1], [2, 4, 1], [4, 2, 1], [4, 4, 1]]
                cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape
            else:
                # Static clusters: single shape; the per-CTA tile divides the
                # threadblock shape by the cluster shape.
                runtime_cluster_shapes = [operation.tile_description.cluster_shape]
                cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0])
                cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1])
                cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2])
            alignment_a = operation.A.alignment
            alignment_b = operation.B.alignment
            alignment_c = operation.C.alignment
            alignment_ab_max = max(alignment_a, alignment_b)
            layout3x = operation.layout_name_3x()
            data_types = operation.datatype_name_3x()
            ctas_per_mma_instruction = 2 if '_2sm' in kernel_name else 1
            # Remove any cluster shapes that have cluster_m that is not divisible by 2.
            # NOTE(review): this also drops odd cluster_m shapes (e.g. 1x1x1) for
            # 1SM kernels -- confirm that is intended.
            runtime_cluster_shapes = [cs for cs in runtime_cluster_shapes if cs[0] % 2 == 0]
            kernel_problem_waves = problem_waves
            if mode == "functional_L0" or mode == "functional_L1":
                # for functional testing, we want to perturb just a little from even shapes
                # large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not
                # -16 ensures that we are TMA aligned even for FP8/Int8
                min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max
                max_k = (cta_tile_shape_k * 8) - alignment_ab_max
                problem_shapes_k = [min_k, max_k]
                sm_count = 16
                # Larger k and less than half wave trigger streamk + separate reduction case to be generated
                if 'stream_k' in kernel_name:
                    problem_shapes_k = [max_k, cta_tile_shape_k * 32]
                    kernel_problem_waves = [0.125, 1.25, 2.5]
            else:
                raise ValueError
            # Void-C kernels must run with beta == 0.
            # Fix: use a per-kernel list instead of mutating beta_values, which
            # previously leaked the override into every subsequent kernel.
            kernel_beta_values = [0] if "void" in kernel_name else beta_values
            alignment_shift_m = max(alignment_c, alignment_a)
            alignment_shift_n = max(alignment_c, alignment_b)
            for index_waves, waves in enumerate(kernel_problem_waves):
                for index_k, k in enumerate(problem_shapes_k):
                    for beta in kernel_beta_values:
                        for cluster_shape in runtime_cluster_shapes:
                            for runtime_input_datatype in runtime_input_datatypes:
                                grid_size = waves * sm_count
                                cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape)
                                # Round the grid up to a whole multiple of the
                                # cluster along its smaller dimension.
                                if cluster_shape_m >= cluster_shape_n:
                                    grid_m = cluster_shape_m
                                    grid_n = grid_size / grid_m
                                    grid_n = max(int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1)
                                else:
                                    grid_n = cluster_shape_n
                                    grid_m = grid_size / grid_n
                                    grid_m = max(int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1)
                                verification_required = False
                                if mode == "functional_L0" or mode == "functional_L1":
                                    if '_void_' not in kernel_name:
                                        verification_required = True
                                m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max)
                                n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max)
                                k = int(k)
                                # For functional testing, we want to perturb just a little from even shapes.
                                # Only do this if the perturbation does not cause one of the dimensions of the
                                # problem size to go to zero. This can occur for blockscaling kernels for which
                                # the alignment requirements for A and B can be quite large (e.g., 256).
                                if m > alignment_shift_m:
                                    m -= alignment_shift_m
                                if n > alignment_shift_n:
                                    n -= alignment_shift_n
                                if '_n32t32_' in kernel_name:
                                    continue
                                batch_count = 1
                                if mode == "functional_L0" or mode == "functional_L1":
                                    if index_waves == 0 and index_k == 0:
                                        batch_count = 3 if mode == "functional_L0" else 5
                                gemm_op = "gemm"
                                profiler_reference_computing_override = profiler_reference_computing
                                if "bstensorop" in kernel_name:
                                    # Block-scaled kernels run in trace mode instead of
                                    # device verification.
                                    profiler_reference_computing_override = "--mode=trace"
                                    gemm_op = "block_scaled_gemm"
                                problem_size_category = ['smallK', 'largeK'][index_k] + '_' + ['beta==0', 'beta!=0'][bool(beta)]
                                assert m > 0 and n > 0 and k > 0
                                # Emit per-testcase metadata for perf testing usage, eventually in perf database
                                metadata_dict = {
                                    "input_params": {
                                        'problem_size_category': problem_size_category,
                                        'operation': _getSubOperationType(operation),
                                        'datatype': data_types,
                                        'layout': layout3x,
                                        'm': m,
                                        'n': n,
                                        'k': k,
                                        'beta': beta,
                                        'flops_per_byte': _computeFlopsPerByte(operation, m, n, k, batch_count, beta)
                                    },
                                    "runtime_params": {
                                        'ctas_per_mma_instruction': ctas_per_mma_instruction,
                                        'tilesize_m': cta_tile_shape_m,
                                        'tilesize_n': cta_tile_shape_n,
                                        'tilesize_k': cta_tile_shape_k,
                                        'cluster_shape_m': cluster_shape_m,
                                        'cluster_shape_n': cluster_shape_n,
                                    }
                                }
                                cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m
                                cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n
                                cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k
                                if dynamic_datatype:
                                    runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype)
                                    metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a
                                    metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b
                                # Fix: a stray trailing '\' after this list glued the next
                                # statement onto it (syntax error); removed.
                                testcase_metadata = [
                                    f"cutlass_profiler --operation={gemm_op} {profiler_reference_computing_override} --error-on-no-match --error-if-nothing-is-profiled" +
                                    f" --kernels={kernel_name}" +
                                    f" --m={str(m)}" +
                                    f" --n={str(n)}" +
                                    f" --k={str(k)}" +
                                    f" --cluster_m={str(cluster_shape_m)}" +
                                    f" --cluster_n={str(cluster_shape_n)}" +
                                    f" --cluster_k={str(cluster_shape_k)}" +
                                    f" --cluster_m_fallback={str(cluster_m_fallback)}" +
                                    f" --cluster_n_fallback={str(cluster_n_fallback)}" +
                                    f" --cluster_k_fallback={str(cluster_k_fallback)}" +
                                    f" --beta={str(beta)}" +
                                    f" --batch_count={str(batch_count)}" +
                                    f" --verification-required={str(verification_required).lower()}"
                                ]
                                if dynamic_datatype:
                                    testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" +
                                                             f" --runtime_input_datatype_b={runtime_datatype_b}")
                                testcase_metadata.append(json.dumps(metadata_dict))
                                testlist_csv_rows.append(testcase_metadata)
                                testcase_counter += 1
                                alpha = 1.0
                                if dynamic_datatype:
                                    # NOTE(review): reassigns the per-kernel hashed name;
                                    # after the first substitution, later loop iterations
                                    # see the already-transformed name (no-op).
                                    hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b)
                                # If kernel_name is new, initialize its feature set with defaults
                                if hashed_kernel_name not in kernel_features:
                                    kernel_features[hashed_kernel_name] = {
                                        "is_support_dynamic_cluster": False,
                                        "is_support_dynamic_datatype": False,
                                    }
                                # Update features for the hashed kernel name
                                kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster
                                kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype
                                if hashed_kernel_name not in auditlist_csv_params_map:
                                    auditlist_csv_params_map[hashed_kernel_name] = []
                                audit_row_params = get_kernel_params(
                                    operation,
                                    hashed_kernel_name,
                                    (cluster_shape_m, cluster_shape_n, cluster_shape_k),
                                    (cluster_m_fallback, cluster_n_fallback, cluster_k_fallback),
                                    (m, n, k, batch_count),
                                    alpha, beta,
                                    dynamic_datatype, dynamic_cluster
                                )
                                auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params)
                                if hashed_kernel_name not in auditlist_csv_map:
                                    audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype)
                                    auditlist_csv_map[hashed_kernel_name] = audit_row
    with open(outfile_name, 'w') as testlist_csv:
        csv_writer = csv.writer(testlist_csv, delimiter=',')
        csv_writer.writerow(testlist_csv_fields)
        csv_writer.writerows(testlist_csv_rows)
    with open(audit_file_name, 'w') as auditlist_csv:
        csv_writer = csv.writer(auditlist_csv, delimiter=',')
        csv_writer.writerow(audit_csv_fields)
        for hashed_kernel_name, row in auditlist_csv_map.items():
            # Append the dynamic features as "Y" or "N"
            dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N"
            dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N"
            test_count = len(auditlist_csv_params_map[hashed_kernel_name])
            csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count])
    with open(audit_file_params_name, 'w') as auditlist_csv:
        csv_writer = csv.writer(auditlist_csv, delimiter=',')
        csv_writer.writerow(audit_csv_runtime_fields)
        for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1):
            for i, row in enumerate(rows):
                # Only the first row of each kernel carries the index/name columns.
                if i == 0:
                    csv_writer.writerow([kernel_index, hashed_kernel_name] + row)
                else:
                    csv_writer.writerow(["", ""] + row)
    print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.")
    # Generate a newline separated list of kernel filters
    assert(len(kernel_name_set) == kernels_emitted)
    output_filter_enabled = True
    if output_filter_enabled:
        kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
        with open(kernel_filter_outfile_name, "w") as file:
            kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set))
            for kernel_name in kernel_name_set:
                file.write(kernel_name + "\n")
    # Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and superMMA kernels in cutlass2.x generated together.
    if mode == "functional_L0" or mode == "functional_L1":
        # Sort the .csv file
        # NOTE(review): sorting raw lines also sorts the header row in among the
        # data rows (pre-existing behavior, preserved).
        outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
        with open(outfile_name) as file:
            data = file.readlines()
        data.sort()
        with open(outfile_name, 'w') as file:
            file.writelines(data)
        # Sort the kernel list
        kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
        with open(kernel_filter_outfile_name) as file:
            data = file.readlines()
        data.sort()
        with open(kernel_filter_outfile_name, 'w') as file:
            file.writelines(data)

View File

@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@ -64,7 +64,7 @@ class GemmOperation:
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
@ -74,6 +74,7 @@ class GemmOperation:
GemmKind.Universal3x,
GemmKind.SparseUniversal3x,
GemmKind.BlockScaledUniversal3x,
GemmKind.GroupedGemmUniversal3x,
}
self.is_3x = gemm_kind in kinds_3x
self.prefix = "3x" if self.is_3x else ""
@ -111,6 +112,12 @@ class GemmOperation:
self.swizzling_functor = swizzling_functor
self.tile_scheduler = tile_scheduler
# Only enable mixed input mode and mixed input shuffle for Hopper
self.mixed_input_mode = None
if self.is_mixed_input() and self.arch >= 90 and self.arch < 100:
self.mixed_input_mode = mixed_input_mode
self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle
#
def is_complex(self):
complex_operators = [
@ -211,6 +218,18 @@ class GemmOperation:
return extended_name
#
def mixed_input_mode_name(self):
mode_name_mapping = {
MixedInputMode.ConvertOnly: "_cvt",
MixedInputMode.ScaleOnly: "_scl",
MixedInputMode.ScaleWithZeroPoint: "_sclzr"
}
mode_name = mode_name_mapping.get(self.mixed_input_mode, "")
if self.mixed_input_shuffle:
mode_name = mode_name + "_shfl"
return mode_name
def extended_name_3x(self):
'''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
@ -237,6 +256,8 @@ class GemmOperation:
element_d = d_type_names,
core_name = self.core_name())
if self.mixed_input_mode != None:
extended_name = extended_name + self.mixed_input_mode_name()
return extended_name
def datatype_name_3x(self):
@ -768,6 +789,8 @@ using ${operation_name}_epilogue =
${epilogue_functor}
>::CollectiveOp;
${mixed_dtype_prepare_code}
using ${operation_name}_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
${arch}, ${opcode_class_main},
@ -782,7 +805,7 @@ using ${operation_name}_mainloop =
// Gemm operator ${operation_name}
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
${problem_shape},
${operation_name}_mainloop,
${operation_name}_epilogue,
${tile_scheduler}>;
@ -830,7 +853,18 @@ ${compile_guard_end}
return SubstituteTemplate(block_scaled_template, block_scaled_values)
#
@staticmethod
def pointerize_if_grouped(operation, layout):
return layout if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else layout + "* "
@staticmethod
def problem_shape(operation):
gemm_shape_type = "cute::Shape<int,int,int,int>"
grouped_gemm_shape_type = "cute::Shape<int,int,int>"
grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
return gemm_shape_type if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else grouped_gemm_shape_type
def emit(self, operation):
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
_LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name())
@ -926,17 +960,83 @@ ${compile_guard_end}
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
operation_name_str = operation.procedural_name()
layout_a_str = LayoutTag[instance_layout_A]
layout_b_str = LayoutTag[instance_layout_B]
mixed_dtype_prepare_code = ""
if operation.mixed_input_mode != None:
A_dtype = operation.A.element
B_dtype = operation.B.element
A_dtype_bits = DataTypeSize[A_dtype]
B_dtype_bits = DataTypeSize[B_dtype]
is_A_dtype_narrow = A_dtype_bits < B_dtype_bits
if is_A_dtype_narrow:
narrow_dtype, wide_dtype = (A_dtype, B_dtype)
narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
else:
narrow_dtype, wide_dtype = (B_dtype, A_dtype)
narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)
narrow_tag = DataTypeTag[narrow_dtype]
wide_tag = DataTypeTag[wide_dtype]
scale_tag = DataTypeTag[wide_dtype]
zero_tag = DataTypeTag[wide_dtype]
do_shuffle = False
value_shuffle_str = ""
if narrow_dtype_bits == 4 and wide_dtype_bits == 16:
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_4>, cute::Stride<cute::_4,cute::_1>>"
do_shuffle = True
if narrow_dtype_bits == 8 and wide_dtype_bits == 16:
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_2>, cute::Stride<cute::_2,cute::_1>>"
do_shuffle = True
do_shuffle = operation.mixed_input_shuffle and do_shuffle
if do_shuffle:
if is_A_dtype_narrow:
stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>"
layout_a_str = f"{operation_name_str}_LayoutNarrowReordered"
else:
stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>"
layout_b_str = f"{operation_name_str}_LayoutNarrowReordered"
# The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and
# layout_{a, b}_str are to prevent errors in Windows platform unity build
mixed_dtype_prepare_code = f"""
using {operation_name_str}_StrideNarrow = {stride_narrow_str};
using {operation_name_str}_ValueShuffle = {value_shuffle_str};
static constexpr int {operation_name_str}_NumShuffleAtoms = 1;
using {operation_name_str}_MmaAtomShape = cute::Layout<cute::Shape<cute::_1, cute::Int<{operation_name_str}_NumShuffleAtoms>>>;
using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>());
using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
"""
mixed_input_modes_to_element = {
MixedInputMode.ConvertOnly: narrow_tag,
MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>",
MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>"
}
narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag)
if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2):
narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>"
if is_A_dtype_narrow:
element_a = narrow_element
else:
element_b = narrow_element
values = {
'operation_name': operation.procedural_name(),
'operation_name': operation_name_str,
'operation_suffix': self.operation_suffix,
'problem_shape': self.problem_shape(operation),
'element_a': element_a,
'layout_a': LayoutTag[instance_layout_A],
'layout_a': self.pointerize_if_grouped(operation, layout_a_str),
'element_b': element_b,
'layout_b': LayoutTag[instance_layout_B],
'layout_b': self.pointerize_if_grouped(operation, layout_b_str),
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[instance_layout_C],
'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]),
'element_d': DataTypeTag[operation.D.element],
'layout_d': LayoutTag[instance_layout_D],
'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]),
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class_main': OpcodeClassTag[opcode_class_main],
'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
@ -968,6 +1068,7 @@ ${compile_guard_end}
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]),
'mixed_dtype_prepare_code': mixed_dtype_prepare_code
}
return SubstituteTemplate(self.gemm_template, values)
@ -1294,7 +1395,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
GemmKind.Grouped: EmitGemmGroupedInstance
GemmKind.Grouped: EmitGemmGroupedInstance,
GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance,
}
self.gemm_kind_wrappers = {
@ -1306,7 +1408,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation',
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
GemmKind.Grouped: 'GemmGroupedOperation'
GemmKind.Grouped: 'GemmGroupedOperation',
GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation'
}
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
@ -1363,6 +1466,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
("library_internal.h", None),
("gemm_operation.h", None),
("gemm_operation_3x.hpp", None),
("grouped_gemm_operation_3x.hpp", None),
("sparse_gemm_operation_3x.hpp", None),
("block_scaled_gemm_operation_3x.hpp", None),
("cutlass/arch/wmma.h", None),

View File

@ -90,9 +90,11 @@ try:
raise ImportError("Disabling attempt to import cutlass_library")
from cutlass_library.library import *
from cutlass_library.manifest import *
from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist
except ImportError:
from library import *
from manifest import *
from emit_kernel_listing import emit_gemm_kernel_testlist
###################################################################################################
#
@ -177,7 +179,8 @@ def CreateGemmUniversal3xOperator(
complex_transforms=None,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity1,
tile_schedulers=[TileSchedulerType.Default]):
tile_schedulers=[TileSchedulerType.Default],
gemm_kind=GemmKind.Universal3x):
if type(data_types) is dict:
data_types = [data_types]
@ -206,7 +209,6 @@ def CreateGemmUniversal3xOperator(
D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1])
gemm_op_extra_args = {}
gemm_kind = GemmKind.Universal3x
element_compute = data_type.get("epi_type", data_type["acc_type"])
@ -218,16 +220,43 @@ def CreateGemmUniversal3xOperator(
gemm_kind = GemmKind.BlockScaledUniversal3x
operation = GemmOperation(
gemm_kind, tile_description.minimum_compute_capability,
tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D,
kernel_schedule, epilogue_schedule, tile_scheduler, **gemm_op_extra_args)
A_dtype = data_type["a_type"]
B_dtype = data_type["b_type"]
A_dtype_bits = DataTypeSize[A_dtype]
B_dtype_bits = DataTypeSize[B_dtype]
is_A_dtype_narrow = A_dtype_bits < B_dtype_bits
if is_A_dtype_narrow:
narrow_dtype, wide_dtype = (A_dtype, B_dtype)
narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
else:
narrow_dtype, wide_dtype = (B_dtype, A_dtype)
narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)
manifest.append(operation)
operations.append(operation)
mixed_input_modes = [None]
if narrow_dtype_bits != wide_dtype_bits:
if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2):
mixed_input_modes = [MixedInputMode.ScaleOnly]
else:
mixed_input_modes = [MixedInputMode.ConvertOnly, MixedInputMode.ScaleOnly, MixedInputMode.ScaleWithZeroPoint]
mixed_input_shuffle_options = [False]
if (mixed_input_modes[0] is not None) and (wide_dtype_bits == 16) and (narrow_dtype_bits == 4 or narrow_dtype_bits == 8):
mixed_input_shuffle_options = [False, True]
for mixed_input_mode, mixed_input_shuffle in product(mixed_input_modes, mixed_input_shuffle_options):
operation = GemmOperation(
gemm_kind, tile_description.minimum_compute_capability,
tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D,
kernel_schedule, epilogue_schedule, tile_scheduler,
mixed_input_mode=mixed_input_mode, mixed_input_shuffle=mixed_input_shuffle, **gemm_op_extra_args)
manifest.append(operation)
operations.append(operation)
return operations
def is_grouped(gemm_kind):
return gemm_kind == GemmKind.GroupedGemmUniversal3x
# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts
def CreateSparseGemmUniversal3xOperator(
manifest, layouts, tile_descriptions, data_types,
@ -4934,12 +4963,7 @@ def GenerateSM80(manifest, cuda_version):
###################################################################################################
def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):
if (
not CudaToolkitVersionSatisfies(cuda_version, 12, 4)
):
return
def GenerateSM89_TensorOp_16832_fp8(manifest, element_acc):
layouts = [
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor)
@ -4948,49 +4972,48 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):
math_instructions = [
MathInstruction(
[16, 8, 32],
DataType.e4m3, DataType.e4m3, DataType.f32,
DataType.e4m3, DataType.e4m3, element_acc,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[16, 8, 32],
DataType.e4m3, DataType.e5m2, DataType.f32,
DataType.e4m3, DataType.e5m2, element_acc,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[16, 8, 32],
DataType.e5m2, DataType.e4m3, DataType.f32,
DataType.e5m2, DataType.e4m3, element_acc,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[16, 8, 32],
DataType.e5m2, DataType.e5m2, DataType.f32,
DataType.e5m2, DataType.e5m2, element_acc,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
MathInstruction(
[16, 8, 32],
DataType.e4m3, DataType.e4m3, DataType.f32,
DataType.e4m3, DataType.e4m3, element_acc,
OpcodeClass.TensorOp,
MathOperation.multiply_add_fast_accum),
MathInstruction(
[16, 8, 32],
DataType.e4m3, DataType.e5m2, DataType.f32,
DataType.e4m3, DataType.e5m2, element_acc,
OpcodeClass.TensorOp,
MathOperation.multiply_add_fast_accum),
MathInstruction(
[16, 8, 32],
DataType.e5m2, DataType.e4m3, DataType.f32,
DataType.e5m2, DataType.e4m3, element_acc,
OpcodeClass.TensorOp,
MathOperation.multiply_add_fast_accum),
MathInstruction(
[16, 8, 32],
DataType.e5m2, DataType.e5m2, DataType.f32,
DataType.e5m2, DataType.e5m2, element_acc,
OpcodeClass.TensorOp,
MathOperation.multiply_add_fast_accum),
]
min_cc = 89
max_cc = 89
max_cc = 100
alignment_constraints = [16,]
alignment_constraints_small_channels = [16, 8, 4]
@ -5077,6 +5100,18 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version):
else:
op.C.alignment = 8
def GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 4):
return
GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f32)
def GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f16)
#
def GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version):
@ -5177,7 +5212,8 @@ def GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version):
#
def GenerateSM89(manifest, cuda_version):
GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version)
GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version)
GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version)
GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version)
###################################################################################################
@ -5189,6 +5225,7 @@ try:
generate_tf32_math_instructions_sm90,
generate_int8_math_instructions_sm90,
generate_fp8_math_instructions_sm90,
generate_mixed_dtype_math_instructions_sm90,
make_sparse_math_instructions,
generate_tile_descriptions_sm90,
get_valid_schedules,
@ -5201,6 +5238,7 @@ except ImportError:
generate_tf32_math_instructions_sm90,
generate_int8_math_instructions_sm90,
generate_fp8_math_instructions_sm90,
generate_mixed_dtype_math_instructions_sm90,
make_sparse_math_instructions,
generate_tile_descriptions_sm90,
get_valid_schedules,
@ -5208,8 +5246,8 @@ except ImportError:
fix_alignments,
)
def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0):
return
instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992)
@ -5262,10 +5300,11 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
data_types=data_type,
instantiation_level=instantiation_level,
layout=layout,
gemm_kind=gemm_kind,
)
if len(schedules):
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules)
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind)
if len(stream_k_schedules):
assert CudaToolkitVersionSatisfies(cuda_version, 12, 1)
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type,
@ -5728,8 +5767,8 @@ def GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version):
tile_schedulers=[TileSchedulerType.StreamK])
def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0):
return
instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992)
@ -5783,10 +5822,11 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
data_types=data_type,
instantiation_level=instantiation_level,
layout=layout,
gemm_kind=gemm_kind,
)
if len(schedules):
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules)
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind)
if len(stream_k_schedules):
assert CudaToolkitVersionSatisfies(cuda_version, 12, 1)
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type,
@ -5851,6 +5891,90 @@ def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version):
stream_k_schedules,
tile_schedulers=[TileSchedulerType.StreamK])
def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 1):
return
instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999)
is_aligned = True
# layouts for ABC, their alignments will be fixed later based on the data type
layouts = [
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]],
]
valid_types_for_a_b_acc = [
(DataType.e4m3, DataType.f16, DataType.f32),
(DataType.e4m3, DataType.bf16, DataType.f32),
(DataType.e5m2, DataType.f16, DataType.f32),
(DataType.e5m2, DataType.bf16, DataType.f32),
(DataType.s8, DataType.f16, DataType.f32),
(DataType.s8, DataType.bf16, DataType.f32),
(DataType.u8, DataType.f16, DataType.f32),
(DataType.u8, DataType.bf16, DataType.f32),
(DataType.s4, DataType.f16, DataType.f32),
(DataType.s4, DataType.bf16, DataType.f32),
(DataType.s4, DataType.e4m3, DataType.f32),
(DataType.s4, DataType.e5m2, DataType.f32),
(DataType.u4, DataType.f16, DataType.f32),
(DataType.u4, DataType.bf16, DataType.f32),
(DataType.u2, DataType.f16, DataType.f32),
(DataType.u2, DataType.bf16, DataType.f32),
(DataType.s2, DataType.f16, DataType.f32),
(DataType.s2, DataType.bf16, DataType.f32),
]
# Note: For sizeof(a_type) > sizeof(b_type), some generated kernels might crash due to a compiler bug. Disable it for now.
#swapped_valid_types_for_a_b_acc = [(b_type, a_type, acc_type) for a_type, b_type, acc_type in valid_types_for_a_b_acc]
#valid_types_for_a_b_acc = valid_types_for_a_b_acc + swapped_valid_types_for_a_b_acc
math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc)
valid_types_for_d = [DataType.f32]
valid_types_for_c = [DataType.f32]
tile_descriptions = generate_tile_descriptions_sm90(
math_instructions=math_instructions,
is_aligned=is_aligned,
level=instantiation_level)
for tile_desc in tile_descriptions:
math_inst = tile_desc.math_instruction
data_types = []
for c_type, d_type in product(valid_types_for_c, valid_types_for_d):
data_types.append(
generate_data_types_from_math_instruction(
math_inst,
element_source=c_type,
element_dest=d_type,
)
)
for layout in layouts:
for data_type in data_types:
# Fix alignments, DataTypeSize are in the unit of bits
alignment_bits = 128
layout[0][1] = alignment_bits // DataTypeSize[data_type['a_type']]
layout[1][1] = alignment_bits // DataTypeSize[data_type['b_type']]
layout[2][1] = alignment_bits // DataTypeSize[data_type['c_type']]
schedules, stream_k_schedules = get_valid_schedules(
tile_description=tile_desc,
cuda_version=cuda_version,
is_aligned=is_aligned,
data_types=data_type,
instantiation_level=instantiation_level,
layout=layout,
)
if len(schedules):
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules)
if len(stream_k_schedules):
assert CudaToolkitVersionSatisfies(cuda_version, 12, 1)
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type,
stream_k_schedules,
tile_schedulers=[TileSchedulerType.StreamK])
def GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 2):
@ -6662,7 +6786,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types,
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version):
def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
@ -6680,6 +6804,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version):
min_cc = 100
max_cc = 100
grouped = is_grouped(gemm_kind)
math_instructions_1sm = [
# f16 -> f16
#MathInstruction(
@ -6736,6 +6862,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version):
MathOperation.multiply_add)]
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1],[4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@ -6776,9 +6903,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version):
for layout in layouts:
layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]
kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types,
[[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]],
tile_schedulers=tile_schedulers)
[[kernel_schedule, epi_schedule]],
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
# for mixed precision kernels, also generate kernels that write output matrix in the A/B format
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
@ -6806,8 +6935,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version):
layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]]
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed,
[[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]],
tile_schedulers=tile_schedulers)
[[kernel_schedule, epi_schedule]],
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
# 2xSM MMA kernels
math_instructions_2sm = [
@ -6886,6 +7015,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version):
MathOperation.multiply_add)]
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@ -6921,13 +7051,16 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version):
for layout in layouts:
layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]
if math_inst.instruction_shape[0] == 128:
if grouped:
epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm
elif math_inst.instruction_shape[0] == 128:
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
else:
epi_schedule = EpilogueScheduleType.ScheduleAuto
kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types,
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
# for mixed precision kernels, also generate kernels that write output matrix in the A/B format
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
@ -6955,9 +7088,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version):
layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]]
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed,
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version):
def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
@ -6976,6 +7109,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version):
min_cc = 100
max_cc = 100
epi_type = DataType.f32
grouped = is_grouped(gemm_kind)
math_instructions_1sm = [
# inst 64x128
@ -7038,6 +7172,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version):
MathOperation.multiply_add)]
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@ -7163,9 +7298,14 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version):
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
( data_type["d_type"] == DataType.e5m2 ):
continue
# don't support runtime data type for grouped yet
if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8):
continue
kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]],
tile_schedulers=tile_schedulers)
[[kernel_schedule, epi_schedule]],
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
# 2xSM MMA kernels
math_instructions_2sm = [
@ -7241,6 +7381,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version):
]
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@ -7361,15 +7502,20 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version):
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
( data_type["d_type"] == DataType.e5m2 ):
continue
# don't support runtime data type for grouped yet
if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8):
continue
if math_inst.instruction_shape[0] == 128:
if grouped:
epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm
elif math_inst.instruction_shape[0] == 128:
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
else:
epi_schedule = EpilogueScheduleType.ScheduleAuto
kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
@ -7460,6 +7606,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
[2,1,1],
# [1,4,1],
[4,4,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
@ -7533,6 +7680,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
[4,1,1],
# [4,2,1],
[4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@ -7728,6 +7876,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
[2,1,1],
# [1,4,1],
[4,4,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
@ -7841,6 +7990,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
[4,1,1],
# [4,2,1],
[4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@ -8419,6 +8569,9 @@ def GenerateSM100(manifest, cuda_version):
GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
# grouped GEMM
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
#
# Block Scaled Gemm
#
@ -8800,7 +8953,10 @@ def GenerateSM90(manifest, cuda_version):
GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_1684(manifest, cuda_version)
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM90_TensorOp_1684_complex(manifest, cuda_version)
GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version)
GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version)
@ -8899,6 +9055,12 @@ if __name__ == "__main__":
if 'library' in args.generator_target.split(','):
manifest.emit(GeneratorTarget.Library)
if 'kernel_testlist_l0' in args.generator_target.split(','):
emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L0")
if 'kernel_testlist_l1' in args.generator_target.split(','):
emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L1")
if args.selected_kernel_list is not None:
if len(manifest.selected_kernels) > 0:
with open(args.selected_kernel_list, 'w') as file_writer:

View File

@ -485,19 +485,25 @@ class KernelScheduleType(enum.Enum):
TmaWarpSpecialized1SmSm100 = enum_auto()
TmaWarpSpecialized2SmSm100 = enum_auto()
PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto()
PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto()
BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto()
BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto()
Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
#
KernelScheduleTag = {
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
@@ -516,19 +522,24 @@ KernelScheduleTag = {
KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100',
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100',
KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100',
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
}
#
@@ -549,39 +560,54 @@ KernelScheduleSuffixes = {
KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm',
KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: '_1sm',
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: '_2sm',
KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm',
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
}
class EpilogueScheduleType(enum.Enum):
  """Epilogue schedule selections for CUTLASS 3.x kernel generation.

  Members are mapped to fully-qualified C++ tags by EpilogueScheduleTag and to
  procedural-name suffixes by EpilogueScheduleSuffixes.  The PtrArray* members
  are the ptr-array / grouped-GEMM variants of the corresponding schedules;
  the *1Sm / *2Sm members appear to distinguish one- vs two-SM epilogues on
  SM100 -- TODO confirm against the C++ tag definitions.
  """
  ScheduleAuto = enum_auto()
  EpilogueTransposed = enum_auto()
  NoSmemWarpSpecialized = enum_auto()
  PtrArrayNoSmemWarpSpecialized = enum_auto()
  TmaWarpSpecialized = enum_auto()
  TmaWarpSpecializedCooperative = enum_auto()
  TmaWarpSpecialized1Sm = enum_auto()
  TmaWarpSpecialized2Sm = enum_auto()
  PtrArrayTmaWarpSpecialized1Sm = enum_auto()
  PtrArrayTmaWarpSpecialized2Sm = enum_auto()
  PtrArrayTmaWarpSpecializedPingpong = enum_auto()
  PtrArrayTmaWarpSpecializedCooperative = enum_auto()
#
# Maps each EpilogueScheduleType member to the fully-qualified C++ type tag
# that is emitted into generated kernel instantiations.
EpilogueScheduleTag = {
  EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
  EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
  EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
  EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
  EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
  EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
  EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
  EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm',
  EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm',
  EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm',
  EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative',
  EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong',
}
#
@@ -589,10 +615,15 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.ScheduleAuto: '',
EpilogueScheduleType.EpilogueTransposed: '',
EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm',
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong',
}
class EpilogueFunctor3x(enum.Enum):
@@ -786,6 +817,7 @@ class GemmKind(enum.Enum):
PlanarComplexArray = enum_auto()
Grouped = enum_auto()
BlockScaledUniversal3x = enum_auto()
GroupedGemmUniversal3x = enum_auto()
#
GemmKindNames = {
@@ -797,7 +829,8 @@ GemmKindNames = {
GemmKind.PlanarComplex: "gemm_planar_complex",
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
GemmKind.Grouped: "gemm_grouped",
GemmKind.BlockScaledUniversal3x: "gemm_block_scaled"
GemmKind.BlockScaledUniversal3x: "gemm_block_scaled",
GemmKind.GroupedGemmUniversal3x: "gemm_grouped",
}
#
@@ -838,6 +871,12 @@ EpilogueFunctorTag = {
EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
}
#
class MixedInputMode(enum.Enum):
  """Operand-conversion modes for mixed-input-dtype GEMMs: convert the
  narrower operand directly, apply a scale, or apply a scale together with a
  zero-point offset."""
  ConvertOnly = enum_auto()
  ScaleOnly = enum_auto()
  ScaleWithZeroPoint = enum_auto()
#
class SwizzlingFunctor(enum.Enum):
Identity1 = enum_auto()

View File

@@ -43,7 +43,7 @@ import os.path
import shutil
import sys
import copy
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence, Tuple, List
try:
import builtins
@@ -153,6 +153,17 @@ def generate_int8_math_instruction_shapes_sm90(level: int):
]
return filtered_list_of_wgmma_shapes
def generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level: int, a_type: DataType, b_type: DataType):
  """Select WGMMA instruction shapes for a mixed-input-dtype GEMM.

  Dispatches on the operand widths: a 4-byte operand uses the tf32 shape set,
  a 2-byte operand the fp16/bf16 set, and anything narrower the fp8 set.
  """
  # DataTypeSize is expressed in bits; convert to bytes for the dispatch below.
  operand_widths = (DataTypeSize[a_type] // 8, DataTypeSize[b_type] // 8)
  if 4 in operand_widths:
    return generate_tf32_math_instruction_shapes_sm90(wgmma_level)
  if 2 in operand_widths:
    return generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level)
  return generate_fp8_math_instruction_shapes_sm90(wgmma_level)
###########
def generate_tf32_math_instructions_sm90(level: int):
@@ -219,6 +230,22 @@ def generate_fp8_math_instructions_sm90(level: int):
]
return math_instructions
def generate_mixed_dtype_math_instructions_sm90(level: int, types_of_a_b_acc: List[Tuple[DataType, DataType, DataType]]):
  """Build the MathInstruction list for mixed-input-dtype SM90 GEMMs.

  For every (A, B, accumulator) dtype triple, emits one TensorOp
  multiply-add MathInstruction per instruction shape chosen for that
  dtype pairing at the given instantiation level.
  """
  wgmma_level = get_wgmma_level_from_global_level(level)
  return [
    MathInstruction(
      shape,
      a_type, b_type, acc_type,
      OpcodeClass.TensorOp,
      MathOperation.multiply_add
    )
    for a_type, b_type, acc_type in types_of_a_b_acc
    for shape in generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level, a_type, b_type)
  ]
def generate_int8_math_instructions_sm90(level: int):
wgmma_level = get_wgmma_level_from_global_level(level)
math_instructions = []
@@ -407,7 +434,7 @@ def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types):
def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout,
instantiation_level, enable_fp8_fast_acc=True):
instantiation_level, enable_fp8_fast_acc=True, gemm_kind=GemmKind.Universal3x):
# Level 0: prune according to existing generator.py behavior
# Level >= 1: no pruning
level = get_pruning_level_from_global_level(instantiation_level)
@@ -428,8 +455,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES
requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor
is_sparse = tile_description.math_instruction.opcode_class == OpcodeClass.SparseTensorOp
can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description)
can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types)
@@ -464,6 +489,16 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0):
return [], []
grouped = gemm_kind == GemmKind.GroupedGemmUniversal3x
if grouped:
# the following cases are unsupported by grouped GEMM
if not is_aligned:
return [], []
if not can_do_tma_epilogue:
return [], []
if requires_transposed_epilogue:
return [], []
# Early pruning
if level < 1:
# Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64
@@ -477,20 +512,23 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
if not is_void_c or d_type not in FP8_TYPES:
return [], []
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
return [
schedules = []
if not grouped:
schedules.append(
[
KernelScheduleType.TmaWarpSpecializedCooperative,
EpilogueScheduleType.TmaWarpSpecializedCooperative
])
schedules.append(
[
KernelScheduleType.TmaWarpSpecializedCooperative,
EpilogueScheduleType.TmaWarpSpecializedCooperative
],
[
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
EpilogueScheduleType.TmaWarpSpecializedCooperative
],
] , []
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum if not grouped else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
EpilogueScheduleType.TmaWarpSpecializedCooperative if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
])
return schedules, []
return [], []
if is_fp8 and not is_large_fp8_tile:
valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16]
valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16, DataType.void]
# Prune all configs with fp8 source, and all configs with non-fp8 output
# that have different dtypes for source and output.
if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type):
@@ -504,6 +542,33 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
if is_void_c and not can_do_tma_epilogue:
return [], []
# For mixed input data types
a_type_size = DataTypeSize[data_types["a_type"]]
b_type_size = DataTypeSize[data_types["b_type"]]
if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
schedules = []
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
if a_type_size > b_type_size:
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
schedules.append([
KernelScheduleType.TmaWarpSpecialized,
epilogue_schedule
])
schedules.append([
KernelScheduleType.TmaWarpSpecializedPingpong,
epilogue_schedule
])
if cta_m >= 128:
if a_type_size > b_type_size:
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
else:
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
epilogue_schedule
])
return schedules, []
if not is_aligned:
schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
default_epilogue]]
@@ -521,6 +586,15 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
return schedules, stream_k_schedules
if grouped:
pingpong = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
cooperative = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum
if can_do_tma_epilogue:
schedules.append([pingpong, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong])
if can_do_cooperative:
schedules.append([cooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative])
return schedules, []
schedules = []
# Pruning: emit Void-C kernels with persistent kernels only
if level >= 1 or not is_void_c: