update 3.8 v2 (#2112)

* update 3.8 v2

* update 3.8

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Yujia Zhai authored on 2025-02-19 19:03:14 -08:00, committed by GitHub
parent e9627ce55b
commit b84e9802d8
166 changed files with 3986 additions and 4037 deletions

View File

@ -532,29 +532,12 @@ def tuple_factory_(input_tuple, dtype, constants=[0,1]):
if first_non_empty_base is None:
first_non_empty_base = []
# Determine whether or not to add an additional byte for empty base classes
additional_byte = False
# Special case for constant tuple
if first_non_empty_base is None:
additional_byte = False
else:
for base in first_non_empty_base:
if base in empty_bases:
additional_byte = True
break
if additional_byte:
ctype_fields = [("empty_byte", EmptyByte), ] + ctype_fields
# Create the ctype tuple
class TupleType(ctypes.Structure):
_fields_ = ctype_fields
def __init__(self, args) -> None:
if additional_byte:
fields = self._fields_[1:]
else:
fields = self._fields_
fields = self._fields_
assert len(fields) == len(args)
for field, arg in zip(fields, args):
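
The hunk above simplifies the generated ctypes tuple type so that __init__ always zips the declared fields with the incoming arguments (the empty-base-class padding-byte handling is removed). A minimal, self-contained sketch of that pattern, using hypothetical field names rather than the real CUTLASS dtype machinery:

```python
import ctypes

# Hypothetical field list standing in for the generated ctype_fields.
ctype_fields = [("m", ctypes.c_int), ("n", ctypes.c_int), ("k", ctypes.c_int)]

class TupleType(ctypes.Structure):
    _fields_ = ctype_fields

    def __init__(self, args) -> None:
        # Assign positional args to the declared fields, one-to-one.
        fields = self._fields_
        assert len(fields) == len(args)
        for (name, _ctype), arg in zip(fields, args):
            setattr(self, name, arg)

t = TupleType((128, 256, 64))
print(t.m, t.n, t.k)  # 128 256 64
```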

View File

@ -69,7 +69,7 @@ using ${operation_name}_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
${arch},
${opcode_class_epi},
${output_cta_tile_shape}, // output cta tile shape
${mma_tile_shape}, // mma tile shape
${cluster_shape}, // cluster shape
${epi_tile_mn},
${element_accumulator},
@ -109,26 +109,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
def arch_number_to_type(self, arch: int) -> str:
return f"cutlass::arch::Sm{arch}"
def output_cta_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
# For all three kinds of convolutions, the tile shape's K mode
# differs from GEMM in that it needs to be wrapped in a Shape.
# For Wgrad convolutions specifically,
# the N tile shape also needs to be wrapped in a Shape.
m_template = 'cute::_${cta_m}'
if operation.conv_kind == ConvKind.Wgrad:
n_template = 'cute::Shape<cute::_${cta_n}>'
else:
n_template = 'cute::_${cta_n}'
k_template = 'cute::Shape<cute::_${cta_k}>'
output_cta_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
values = {
'cta_m': cta_m,
'cta_n': cta_n,
'cta_k': cta_k
}
return Template(output_cta_tile_shape_template).substitute(values)
def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
mma_m = cta_m
mma_n = cta_n
@ -223,7 +203,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class': opcode_class,
'arch': self.arch_number_to_type(operation.arch),
'output_cta_tile_shape': self.output_cta_tile_shape(operation, cta_m, cta_n, cta_k),
'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
'cluster_shape': self.cluster_shape(operation),
'opcode_class_epi': opcode_class_epi,
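
For reference, the output_cta_tile_shape helper dropped above built its shape string with string.Template substitution. A standalone sketch of that templating, assuming conv_kind is passed here as a plain string rather than a ConvKind enum:

```python
from string import Template

def output_cta_tile_shape(conv_kind: str, cta_m: int, cta_n: int, cta_k: int) -> str:
    # K is always wrapped in a cute::Shape; N is additionally wrapped for Wgrad.
    m_template = 'cute::_${cta_m}'
    n_template = 'cute::Shape<cute::_${cta_n}>' if conv_kind == 'wgrad' else 'cute::_${cta_n}'
    k_template = 'cute::Shape<cute::_${cta_k}>'
    shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
    return Template(shape_template).substitute(cta_m=cta_m, cta_n=cta_n, cta_k=cta_k)

print(output_cta_tile_shape('fprop', 128, 128, 64))
# cute::Shape<cute::_128, cute::_128, cute::Shape<cute::_64>>
```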

View File

@ -90,19 +90,32 @@ def hash_cutlass_string(input_string):
def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
# Define a dictionary mapping the detected types to runtime values
datatype_map = {
'_f4_': '_' + runtime_datatype_a + '_',
'_f6_': '_' + runtime_datatype_b + '_',
'_f8_': '_' + runtime_datatype_a + '_',
'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b,
'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
}
# Use regex to identify and replace _f4_, _f6_, or _f8_ in the kernel name
def substitute(match):
datatype = match.group(0) # This is the matched "_f4_", "_f6_", or "_f8_"
return datatype_map.get(datatype, datatype) # Replace or leave as is
# Regular expression to detect all the keys in datatype_map
pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')')
# Replace detected patterns using the dictionary
updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name)
# Regex to find "_f4_", "_f6_", or "_f8_" in the hashed_kernel_name
updated_kernel_name = re.sub(r'_f4_|_f6_|_f8_', substitute, hashed_kernel_name)
return updated_kernel_name
# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k.
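
Both variants of transform_hashed_string shown above drive a regex substitution from a datatype dictionary. A self-contained sketch of that dictionary-driven substitution, with a hypothetical two-entry map:

```python
import re

def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
    # Hypothetical subset of the datatype_map above.
    datatype_map = {
        'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
        'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    }
    # Escape each key and join them into one alternation pattern; every match
    # in the hashed kernel name is replaced in a single pass.
    pattern = re.compile('|'.join(map(re.escape, datatype_map.keys())))
    return pattern.sub(lambda m: datatype_map[m.group(0)], hashed_kernel_name)

print(transform_hashed_string('gemm_f4_f4_tnt', 'e2m1', 'e2m1'))
# gemm_e2m1_e2m1_tnt
```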

View File

@ -64,17 +64,15 @@ class GemmOperation:
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
):
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False,
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None):
kinds_3x = {
GemmKind.Universal3x,
GemmKind.SparseUniversal3x,
GemmKind.BlockScaledUniversal3x,
GemmKind.GroupedGemmUniversal3x,
GemmKind.GroupedUniversal3x,
GemmKind.GroupedBlockScaledUniversal3x,
}
self.is_3x = gemm_kind in kinds_3x
self.prefix = "3x" if self.is_3x else ""
@ -87,13 +85,11 @@ class GemmOperation:
self.C = C
self.D = D
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
if is_block_scaled(gemm_kind):
self.ScaleFactorA = ScaleFactorA
self.ScaleFactorB = ScaleFactorB
self.ScaleFactorD = ScaleFactorD["tensor"]
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
if self.D == None:
self.D = self.C
@ -239,13 +235,13 @@ class GemmOperation:
element_c = DataTypeNames[self.C.element],
element_d = DataTypeNames[self.D.element],
core_name = self.core_name())
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
if is_block_scaled(self.gemm_kind):
d_type_names = DataTypeNames[self.D.element]
if self.ScaleFactorD.element != DataType.void:
d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names
extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_sfa = DataTypeNames[self.ScaleFactorA],
element_a = DataTypeNames[self.A.element],
@ -255,7 +251,7 @@ class GemmOperation:
element_c = DataTypeNames[self.C.element],
element_d = d_type_names,
core_name = self.core_name())
if self.mixed_input_mode != None:
extended_name = extended_name + self.mixed_input_mode_name()
return extended_name
@ -298,8 +294,8 @@ class GemmOperation:
# Generates a short string representing underlying epilogue schedule type
def epilogue_schedule_name_3x(self):
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
if is_block_scaled(self.gemm_kind):
if self.ScaleFactorD.element != DataType.void:
return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout]
@ -779,7 +775,7 @@ class EmitGemmUniversal3xInstance:
using ${operation_name}_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
${arch}, ${opcode_class_epi},
cute::Shape<cute::_${tile_shape_epi_m}, cute::_${tile_shape_epi_n}, cute::_${tile_shape_epi_k}>,
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
${epi_tile_mn},
${element_accumulator}, ${element_epilogue},
@ -797,7 +793,7 @@ using ${operation_name}_mainloop =
${element_a}, ${layout_a}, ${align_a},
${element_b}, ${layout_b}, ${align_b},
${element_accumulator},
cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
${stages},
${kernel_schedule}
@ -855,7 +851,7 @@ ${compile_guard_end}
@staticmethod
def pointerize_if_grouped(operation, layout):
return layout if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else layout + "* "
return layout if not is_grouped(operation.gemm_kind) else layout + "* "
@staticmethod
def problem_shape(operation):
@ -863,7 +859,7 @@ ${compile_guard_end}
grouped_gemm_shape_type = "cute::Shape<int,int,int>"
grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
return gemm_shape_type if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else grouped_gemm_shape_type
return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type
def emit(self, operation):
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
@ -874,18 +870,12 @@ ${compile_guard_end}
opcode_class_main = operation.tile_description.math_instruction.opcode_class
opcode_class_epi = opcode_class_main
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
if operation.epilogue_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
opcode_class_epi = OpcodeClass.TensorOp
tile_shape = operation.tile_description.tile_shape
instruction_shape = operation.tile_description.math_instruction.instruction_shape
cluster_m = operation.tile_description.cluster_shape[0]
cluster_n = operation.tile_description.cluster_shape[1]
tile_shape_main_m, tile_shape_main_n, tile_shape_main_k = tile_shape
tile_shape_epi_m, tile_shape_epi_n, tile_shape_epi_k = tile_shape
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
# account for static/dynamic cluster shapes
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
@ -902,10 +892,8 @@ ${compile_guard_end}
if opcode_class_main in [OpcodeClass.TensorOp
, OpcodeClass.BlockScaledTensorOp
]:
tile_shape_main_m = instruction_shape[0]
tile_shape_main_n = instruction_shape[1]
tile_shape_epi_m = cta_m
tile_shape_epi_n = cta_n
tile_shape_m = instruction_shape[0]
tile_shape_n = instruction_shape[1]
# stage count set to zero indicates builder automatic stage selection
@ -930,35 +918,36 @@ ${compile_guard_end}
}
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
else:
epilogue_functor = self.epilogue_functor.emit_declaration()
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
#
# Cutlass 3.x complex kernels' ElementA/ElementB is a tuple in the collective mainloop builder, e.g. cute::tuple<Element, Transform>, where Transform is cute::identity or cute::conjugate.
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100:
is_no_smem_epilogue = operation.epilogue_schedule in [EpilogueScheduleType.NoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
grouped = is_grouped(operation.gemm_kind)
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped):
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100:
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped):
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
operation_name_str = operation.procedural_name()
layout_a_str = LayoutTag[instance_layout_A]
@ -1041,12 +1030,9 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape(
'opcode_class_main': OpcodeClassTag[opcode_class_main],
'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'tile_shape_epi_m': str(tile_shape_epi_m),
'tile_shape_epi_n': str(tile_shape_epi_n),
'tile_shape_epi_k': str(tile_shape_epi_k),
'tile_shape_main_m': str(tile_shape_main_m),
'tile_shape_main_n': str(tile_shape_main_n),
'tile_shape_main_k': str(tile_shape_main_k),
'tile_shape_m': str(tile_shape_m),
'tile_shape_n': str(tile_shape_n),
'tile_shape_k': str(tile_shape_k),
'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int",
'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int",
'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int",
@ -1396,7 +1382,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
GemmKind.Grouped: EmitGemmGroupedInstance,
GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance,
}
self.gemm_kind_wrappers = {
@ -1409,7 +1396,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
GemmKind.Grouped: 'GemmGroupedOperation',
GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation'
GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation',
GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation',
}
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
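
Several of the changes in this file route grouped-kernel decisions through the new is_grouped helper (pointerized layouts, GroupProblemShape). A self-contained sketch of that dispatch, with string stand-ins for GemmKind and hypothetical C++ type strings:

```python
GROUPED_KINDS = {"GroupedUniversal3x", "GroupedBlockScaledUniversal3x"}

def is_grouped(gemm_kind: str) -> bool:
    return gemm_kind in GROUPED_KINDS

def pointerize_if_grouped(gemm_kind: str, layout: str) -> str:
    # Grouped (ptr-array) kernels take one layout pointer per group.
    return layout if not is_grouped(gemm_kind) else layout + "* "

def problem_shape(gemm_kind: str) -> str:
    gemm_shape_type = "cute::Shape<int,int,int,int>"  # assumed non-grouped shape string
    grouped_shape_type = "cutlass::gemm::GroupProblemShape<cute::Shape<int,int,int>>"
    return gemm_shape_type if not is_grouped(gemm_kind) else grouped_shape_type

print(pointerize_if_grouped("GroupedUniversal3x", "cutlass::layout::RowMajor"))
print(problem_shape("GroupedBlockScaledUniversal3x"))
```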

View File

@ -217,8 +217,7 @@ def CreateGemmUniversal3xOperator(
gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"]
gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]),
"vector_size" : data_type["sfd_type"]["vector_size"]}
gemm_kind = GemmKind.BlockScaledUniversal3x
assert is_block_scaled(gemm_kind)
A_dtype = data_type["a_type"]
B_dtype = data_type["b_type"]
@ -254,9 +253,6 @@ def CreateGemmUniversal3xOperator(
return operations
def is_grouped(gemm_kind):
return gemm_kind == GemmKind.GroupedGemmUniversal3x
# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts
def CreateSparseGemmUniversal3xOperator(
manifest, layouts, tile_descriptions, data_types,
@ -6654,11 +6650,13 @@ def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int
sm100_cluster_shape_1sm = [
[4,4,1]
, DynamicClusterShape
]
sm100_cluster_shape_2sm = [
# cluster_m % 2 == 0 for 2sm
[4,4,1]
, DynamicClusterShape
]
def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
@ -6718,6 +6716,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
]
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@ -6765,6 +6764,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
]
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@ -7517,8 +7517,227 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
# SM100 MMA with mixed F4/F6/F8 inputs, without block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
instruction_sizes_1sm = [
# [64, 128, 32],
[128, 128, 32],
# [64, 256, 32],
[128, 256, 32],
]
instruction_sizes_2sm = [
# [128, 128, 32],
# [128, 256, 32],
[256, 128, 32],
[256, 256, 32],
]
ab_types = [
DataType.f4, DataType.f6, DataType.f8,
DataType.e2m1, DataType.e3m2, DataType.e4m3,
]
acc_types = [ DataType.f32 ]
tile_schedulers = [
TileSchedulerType.Default, TileSchedulerType.StreamK
]
min_cc = 100
max_cc = 130
epi_type = DataType.f32
math_instructions_1sm = []
is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)
# Usage: is_runtime_datatype(DataType.f4) -> True
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types):
is_runtime_datatype_a = is_runtime_datatype(a_type)
is_runtime_datatype_b = is_runtime_datatype(b_type)
# A/B datatypes must both be static or both be runtime
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
math_instructions_1sm.append(
MathInstruction(
instr_size,
a_type, b_type, acc_type,
OpcodeClass.TensorOp,
MathOperation.multiply_add)
)
math_instructions_2sm = []
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types):
is_runtime_datatype_a = is_runtime_datatype(a_type)
is_runtime_datatype_b = is_runtime_datatype(b_type)
# A/B datatypes must both be static or both be runtime
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
math_instructions_2sm.append(
MathInstruction(
instr_size,
a_type, b_type, acc_type,
OpcodeClass.TensorOp,
MathOperation.multiply_add)
)
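
The two enumeration loops above build MathInstruction lists from the cartesian product of instruction sizes and A/B/accumulator types, skipping mixed static/runtime A/B pairs. A runnable sketch of that filter, with string stand-ins for the DataType members:

```python
from itertools import product

instruction_sizes = [[128, 128, 32], [128, 256, 32]]
ab_types = ["f4", "f6", "f8", "e2m1", "e3m2", "e4m3"]
acc_types = ["f32"]
# f4/f6/f8 are the runtime-datatype placeholders.
is_runtime_datatype = lambda t: t in ("f4", "f6", "f8")

math_instructions = [
    (size, a, b, acc)
    for size, a, b, acc in product(instruction_sizes, ab_types, ab_types, acc_types)
    # A/B must both be static or both be runtime.
    if is_runtime_datatype(a) == is_runtime_datatype(b)
]
print(len(math_instructions))  # 2 sizes x (9 runtime + 9 static) A/B pairs = 36
```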
cluster_shapes_1sm = [
# [1,2,1],
[2,1,1],
[1,1,1],
# [1,4,1],
[4,4,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_1sm:
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_1sm[0],
math_inst.instruction_shape[1] * multiplier_1sm[1],
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
kernel_data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.f32,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.e5m2,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
}
]
for kernel_data_type in kernel_data_types:
# Filter out some kernels
if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\
( kernel_data_type["d_type"] == DataType.e5m2 ):
continue
# Update layout alignments
# The alignment for D may differ for each kernel_data_type
layouts_copy = copy.deepcopy(layouts)
for layout in layouts_copy:
# alignment for a
layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"])
# alignment for b
layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"])
# alignment for d
layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"])
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
[[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers)
cluster_shapes_2sm = [
[2,1,1],
# [2,2,1],
# [2,4,1],
# [4,1,1],
# [4,2,1],
[4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_2sm[0],
math_inst.instruction_shape[1] * multiplier_2sm[1],
math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
kernel_data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.f32,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.e5m2,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
}
]
for kernel_data_type in kernel_data_types:
# Filter out some kernels
if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\
( kernel_data_type["d_type"] == DataType.e5m2 ):
continue
# Update layout alignments
# The alignment for D may differ for each kernel_data_type
layouts_copy = copy.deepcopy(layouts)
for layout in layouts_copy:
# alignment for a
layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"])
# alignment for b
layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"])
# alignment for d
layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"])
if math_inst.instruction_shape[0] == 128:
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers)
else:
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers)
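
The generators above derive each CTA tile shape by scaling the MMA instruction shape with a per-cluster multiplier, with an extra factor of 4 on the K mode as in the TileDescription constructions shown here; for 2SM kernels the M multiplier is halved because two SMs cooperate on one MMA. A sketch with plain lists instead of TileDescription, assuming DynamicClusterShape is a sentinel that maps to a (1, 1, 1) multiplier:

```python
DynamicClusterShape = [0, 0, 1]  # assumed sentinel value for runtime cluster shapes

def cta_tile(instruction_shape, cluster_shape, two_sm=False):
    if cluster_shape == DynamicClusterShape:
        multiplier = (1, 1, 1)
    elif two_sm:
        # Two SMs cooperate on one MMA, so the M multiplier is halved.
        multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
    else:
        multiplier = tuple(cluster_shape)
    m, n, k = instruction_shape
    # K is scaled by the same *4 factor used in the generators above.
    return [m * multiplier[0], n * multiplier[1], k * 4 * multiplier[2]]

print(cta_tile([128, 256, 32], [2, 1, 1]))               # 1SM: [256, 256, 128]
print(cta_tile([256, 256, 32], [2, 1, 1], two_sm=True))  # 2SM: [256, 256, 128]
```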
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM100 MMA with mixed F4/F6/F8 inputs + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
@ -7529,7 +7748,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
]
instruction_sizes_1sm = [
[128, 128, 32], [128, 256, 32], # Mixed F4/F6/F8 block scaled only supports M=128 for 1SM cases
[128, 128, 32], [128, 256, 32], # Block scaled kernels only support M=128 for 1SM cases
]
instruction_sizes_2sm = [
@ -7670,8 +7889,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
for data_type in data_types:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
cluster_shapes_2sm = [
[2,1,1],
@ -7766,21 +7984,21 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
if math_inst.instruction_shape[0] == 128:
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
else:
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM100 MMA with F4 + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
grouped = is_grouped(gemm_kind)
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
@ -7805,7 +8023,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
def tile_schedulers(sfdtype):
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
# the epilogue is the traditional linear combination, for which we already have tests with stream-K.
if sfdtype["type"] == DataType.void:
if sfdtype["type"] == DataType.void or grouped:
return [TileSchedulerType.Default]
else:
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
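
A compact restatement of the scheduler-selection rule above, with string stand-ins for TileSchedulerType: stream-K is only enumerated for non-void SFD and non-grouped kernels, which keeps the generated kernel count down.

```python
def tile_schedulers(sfd_is_void: bool, grouped: bool):
    # Hypothetical boolean arguments in place of the sfdtype dict and gemm_kind.
    if sfd_is_void or grouped:
        return ["Default"]
    return ["Default", "StreamK"]

print(tile_schedulers(sfd_is_void=False, grouped=True))   # ['Default']
print(tile_schedulers(sfd_is_void=False, grouped=False))  # ['Default', 'StreamK']
```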
@ -7826,6 +8044,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
# grouped GEMM does not support runtime data type yet
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
continue
math_instructions_1sm.append(
MathInstruction(
instr_size,
@ -7853,6 +8075,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
# grouped GEMM does not support runtime data type yet
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
continue
math_instructions_2sm.append(
MathInstruction(
instr_size,
@ -7972,15 +8198,21 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for data_type in data_types:
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
# E2M1 x E2M1, vector size 32, E8
# E2M1 x E2M1, vector size 16, UE4M3
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]
fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped)
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped)
nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule]
fp4_schedule = [fp4_kernel_schedule, epi_schedule]
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule]
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
)
if isFp4:
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule]
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
)
cluster_shapes_2sm = [
@ -8085,18 +8317,20 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for data_type in data_types:
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
# E2M1 x E2M1, vector size 32, E8
# E2M1 x E2M1, vector size 32, E8
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]
fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]
epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped)
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped)
nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule]
fp4_schedule = [fp4_kernel_schedule, epi_schedule]
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule]
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
if isFp4:
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule]
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
@ -8139,6 +8373,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
MathOperation.multiply_add)]
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@ -8237,6 +8472,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
]
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@ -8353,6 +8589,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
cluster_shapes_1sm = [
[1,2,1], [1,1,1], [1,4,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@ -8386,6 +8623,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@ -8431,6 +8669,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
cluster_shapes_1sm = [
[1,2,1], [1,1,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@ -8498,6 +8737,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@ -8554,6 +8794,125 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
[[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
]
min_cc = 100
max_cc = 130
epi_type = DataType.f32
math_instructions_1sm = [
MathInstruction(
[128, 256, 32],
DataType.e4m3, DataType.e4m3, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add)]
cluster_shapes_1sm = [
[1,2,1], [2,1,1], [1,1,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
TileSchedulerType.StreamK,
]
# 1xSM MMA kernels
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_1sm:
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_1sm[0],
math_inst.instruction_shape[1] * multiplier_1sm[1],
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.f16,
"d_type" : DataType.e4m3,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
}]
# Set the D alignment based on the destination format.
for layout in layouts:
layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]
for data_type in data_types:
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
( data_type["d_type"] == DataType.e5m2 ):
continue
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]],
tile_schedulers=tile_schedulers)
# 2xSM MMA kernels
math_instructions_2sm = [
MathInstruction(
[256, 256, 32],
DataType.e4m3, DataType.e4m3, DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add),
]
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_2sm[0],
math_inst.instruction_shape[1] * multiplier_2sm[1],
math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.f16,
"d_type" : DataType.e4m3,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
}]
# Set the D alignment based on the destination format.
for layout in layouts:
layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]
for data_type in data_types:
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
( data_type["d_type"] == DataType.e5m2 ):
continue
if math_inst.instruction_shape[0] == 128:
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
else:
epi_schedule = EpilogueScheduleType.ScheduleAuto
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
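
The D-operand alignment used throughout these generators follows the `layout[2][1] = 128 // DataTypeSize[...]` lines above: 128 bits divided by the element width, so narrower output types get larger alignments in elements. A sketch with an assumed bits-per-element table:

```python
DataTypeSize = {"f32": 32, "f16": 16, "e4m3": 8, "e5m2": 8}  # assumed sizes in bits

def d_alignment(d_type: str) -> int:
    # Alignment in elements for a 128-bit (16-byte) access.
    return 128 // DataTypeSize[d_type]

print(d_alignment("f32"), d_alignment("f16"), d_alignment("e4m3"))  # 4 8 16
```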
def GenerateSM100(manifest, cuda_version):
#
@ -8570,13 +8929,19 @@ def GenerateSM100(manifest, cuda_version):
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
# grouped GEMM
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version)
# StreamK is included in regular generation
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
#
# Block Scaled Gemm
#
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
###################################################################################################
@ -8955,8 +9320,8 @@ def GenerateSM90(manifest, cuda_version):
GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_1684(manifest, cuda_version)
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
GenerateSM90_TensorOp_1684_complex(manifest, cuda_version)
GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version)
GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version)

View File

@ -321,6 +321,12 @@ def is_complex(data_type):
return True
return False
def is_block_scaled(gemm_kind):
return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
def is_grouped(gemm_kind):
return gemm_kind in (GemmKind.GroupedUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
#
def get_complex_from_real(real_type):
for r, c in RealComplexBijection:
@ -482,23 +488,32 @@ class KernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
ImplicitTmaWarpSpecializedSm90 = enum_auto()
TmaWarpSpecialized1SmSm100 = enum_auto()
TmaWarpSpecialized2SmSm100 = enum_auto()
PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto()
PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto()
PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto()
PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto()
PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto()
PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto()
PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto()
PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto()
PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto()
BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto()
Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
@ -519,7 +534,7 @@ KernelScheduleTag = {
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
@ -530,16 +545,25 @@ KernelScheduleTag = {
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100",
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100",
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100",
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100",
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100",
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100",
}
#
@ -568,16 +592,25 @@ KernelScheduleSuffixes = {
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
}
class EpilogueScheduleType(enum.Enum):
@ -585,6 +618,10 @@ class EpilogueScheduleType(enum.Enum):
EpilogueTransposed = enum_auto()
NoSmemWarpSpecialized = enum_auto()
PtrArrayNoSmemWarpSpecialized = enum_auto()
NoSmemWarpSpecialized1Sm = enum_auto()
NoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1Sm = enum_auto()
@ -600,6 +637,10 @@ EpilogueScheduleTag = {
EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
@ -616,6 +657,10 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.EpilogueTransposed: '',
EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
@ -636,6 +681,23 @@ EpilogueFunctor3xTag = {
EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
}
def to_grouped_schedule(schedule, grouped):
if not grouped:
return schedule
group_schedule_map = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100,
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100,
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
}
return group_schedule_map[schedule]
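
A usage sketch of the to_grouped_schedule dispatch above, with a tiny stand-in enum (hypothetical members) instead of the full KernelScheduleType/EpilogueScheduleType:

```python
import enum

class Sched(enum.Enum):
    Nvf4TmaWarpSpecialized1SmSm100 = enum.auto()
    PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum.auto()
    TmaWarpSpecialized1Sm = enum.auto()
    PtrArrayTmaWarpSpecialized1Sm = enum.auto()

GROUP_SCHEDULE_MAP = {
    Sched.Nvf4TmaWarpSpecialized1SmSm100: Sched.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
    Sched.TmaWarpSpecialized1Sm: Sched.PtrArrayTmaWarpSpecialized1Sm,
}

def to_grouped_schedule(schedule, grouped):
    # Non-grouped emission passes the schedule through; grouped emission maps
    # it to its ptr-array variant.
    return schedule if not grouped else GROUP_SCHEDULE_MAP[schedule]

print(to_grouped_schedule(Sched.Nvf4TmaWarpSpecialized1SmSm100, grouped=True).name)
# PtrArrayNvf4TmaWarpSpecialized1SmSm100
```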
class TileSchedulerType(enum.Enum):
Default = enum_auto()
Persistent = enum_auto()
@ -817,7 +879,8 @@ class GemmKind(enum.Enum):
PlanarComplexArray = enum_auto()
Grouped = enum_auto()
BlockScaledUniversal3x = enum_auto()
GroupedGemmUniversal3x = enum_auto()
GroupedUniversal3x = enum_auto()
GroupedBlockScaledUniversal3x = enum_auto()
#
GemmKindNames = {
@ -830,7 +893,8 @@ GemmKindNames = {
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
GemmKind.Grouped: "gemm_grouped",
GemmKind.BlockScaledUniversal3x: "gemm_block_scaled",
GemmKind.GroupedGemmUniversal3x: "gemm_grouped",
GemmKind.GroupedUniversal3x: "gemm_grouped",
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped_block_scaled"
}
#

View File

@ -489,7 +489,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0):
return [], []
grouped = gemm_kind == GemmKind.GroupedGemmUniversal3x
grouped = is_grouped(gemm_kind)
if grouped:
# the following cases are unsupported by grouped GEMM
if not is_aligned: