v3.8.0 update (#2082)

* 3.8 update

* fix Markus' name

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-02-06 18:33:40 -08:00
committed by GitHub
parent affd1b693d
commit 833f6990e0
168 changed files with 24945 additions and 3436 deletions

View File

@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@ -64,7 +64,7 @@ class GemmOperation:
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
@ -74,6 +74,7 @@ class GemmOperation:
GemmKind.Universal3x,
GemmKind.SparseUniversal3x,
GemmKind.BlockScaledUniversal3x,
GemmKind.GroupedGemmUniversal3x,
}
self.is_3x = gemm_kind in kinds_3x
self.prefix = "3x" if self.is_3x else ""
@ -111,6 +112,12 @@ class GemmOperation:
self.swizzling_functor = swizzling_functor
self.tile_scheduler = tile_scheduler
# Only enable mixed input mode and mixed input shuffle for Hopper
self.mixed_input_mode = None
if self.is_mixed_input() and self.arch >= 90 and self.arch < 100:
self.mixed_input_mode = mixed_input_mode
self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle
#
def is_complex(self):
complex_operators = [
@ -211,6 +218,18 @@ class GemmOperation:
return extended_name
#
def mixed_input_mode_name(self):
mode_name_mapping = {
MixedInputMode.ConvertOnly: "_cvt",
MixedInputMode.ScaleOnly: "_scl",
MixedInputMode.ScaleWithZeroPoint: "_sclzr"
}
mode_name = mode_name_mapping.get(self.mixed_input_mode, "")
if self.mixed_input_shuffle:
mode_name = mode_name + "_shfl"
return mode_name
def extended_name_3x(self):
'''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
@ -237,6 +256,8 @@ class GemmOperation:
element_d = d_type_names,
core_name = self.core_name())
if self.mixed_input_mode != None:
extended_name = extended_name + self.mixed_input_mode_name()
return extended_name
def datatype_name_3x(self):
@ -768,6 +789,8 @@ using ${operation_name}_epilogue =
${epilogue_functor}
>::CollectiveOp;
${mixed_dtype_prepare_code}
using ${operation_name}_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
${arch}, ${opcode_class_main},
@ -782,7 +805,7 @@ using ${operation_name}_mainloop =
// Gemm operator ${operation_name}
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
${problem_shape},
${operation_name}_mainloop,
${operation_name}_epilogue,
${tile_scheduler}>;
@ -830,7 +853,18 @@ ${compile_guard_end}
return SubstituteTemplate(block_scaled_template, block_scaled_values)
#
@staticmethod
def pointerize_if_grouped(operation, layout):
return layout if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else layout + "* "
@staticmethod
def problem_shape(operation):
gemm_shape_type = "cute::Shape<int,int,int,int>"
grouped_gemm_shape_type = "cute::Shape<int,int,int>"
grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
return gemm_shape_type if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else grouped_gemm_shape_type
def emit(self, operation):
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
_LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name())
@ -926,17 +960,83 @@ ${compile_guard_end}
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
operation_name_str = operation.procedural_name()
layout_a_str = LayoutTag[instance_layout_A]
layout_b_str = LayoutTag[instance_layout_B]
mixed_dtype_prepare_code = ""
if operation.mixed_input_mode != None:
A_dtype = operation.A.element
B_dtype = operation.B.element
A_dtype_bits = DataTypeSize[A_dtype]
B_dtype_bits = DataTypeSize[B_dtype]
is_A_dtype_narrow = A_dtype_bits < B_dtype_bits
if is_A_dtype_narrow:
narrow_dtype, wide_dtype = (A_dtype, B_dtype)
narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
else:
narrow_dtype, wide_dtype = (B_dtype, A_dtype)
narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)
narrow_tag = DataTypeTag[narrow_dtype]
wide_tag = DataTypeTag[wide_dtype]
scale_tag = DataTypeTag[wide_dtype]
zero_tag = DataTypeTag[wide_dtype]
do_shuffle = False
value_shuffle_str = ""
if narrow_dtype_bits == 4 and wide_dtype_bits == 16:
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_4>, cute::Stride<cute::_4,cute::_1>>"
do_shuffle = True
if narrow_dtype_bits == 8 and wide_dtype_bits == 16:
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_2>, cute::Stride<cute::_2,cute::_1>>"
do_shuffle = True
do_shuffle = operation.mixed_input_shuffle and do_shuffle
if do_shuffle:
if is_A_dtype_narrow:
stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>"
layout_a_str = f"{operation_name_str}_LayoutNarrowReordered"
else:
stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>"
layout_b_str = f"{operation_name_str}_LayoutNarrowReordered"
# The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and
# layout_{a, b}_str are to prevent errors in Windows platform unity build
mixed_dtype_prepare_code = f"""
using {operation_name_str}_StrideNarrow = {stride_narrow_str};
using {operation_name_str}_ValueShuffle = {value_shuffle_str};
static constexpr int {operation_name_str}_NumShuffleAtoms = 1;
using {operation_name_str}_MmaAtomShape = cute::Layout<cute::Shape<cute::_1, cute::Int<{operation_name_str}_NumShuffleAtoms>>>;
using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>());
using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
"""
mixed_input_modes_to_element = {
MixedInputMode.ConvertOnly: narrow_tag,
MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>",
MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>"
}
narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag)
if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2):
narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>"
if is_A_dtype_narrow:
element_a = narrow_element
else:
element_b = narrow_element
values = {
'operation_name': operation.procedural_name(),
'operation_name': operation_name_str,
'operation_suffix': self.operation_suffix,
'problem_shape': self.problem_shape(operation),
'element_a': element_a,
'layout_a': LayoutTag[instance_layout_A],
'layout_a': self.pointerize_if_grouped(operation, layout_a_str),
'element_b': element_b,
'layout_b': LayoutTag[instance_layout_B],
'layout_b': self.pointerize_if_grouped(operation, layout_b_str),
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[instance_layout_C],
'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]),
'element_d': DataTypeTag[operation.D.element],
'layout_d': LayoutTag[instance_layout_D],
'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]),
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class_main': OpcodeClassTag[opcode_class_main],
'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
@ -968,6 +1068,7 @@ ${compile_guard_end}
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]),
'mixed_dtype_prepare_code': mixed_dtype_prepare_code
}
return SubstituteTemplate(self.gemm_template, values)
@ -1294,7 +1395,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
GemmKind.Grouped: EmitGemmGroupedInstance
GemmKind.Grouped: EmitGemmGroupedInstance,
GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance,
}
self.gemm_kind_wrappers = {
@ -1306,7 +1408,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation',
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
GemmKind.Grouped: 'GemmGroupedOperation'
GemmKind.Grouped: 'GemmGroupedOperation',
GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation'
}
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
@ -1363,6 +1466,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
("library_internal.h", None),
("gemm_operation.h", None),
("gemm_operation_3x.hpp", None),
("grouped_gemm_operation_3x.hpp", None),
("sparse_gemm_operation_3x.hpp", None),
("block_scaled_gemm_operation_3x.hpp", None),
("cutlass/arch/wmma.h", None),