v3.8.0 update (#2082)
* 3.8 update * fix Markus' name --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@ -64,7 +64,7 @@ class GemmOperation:
|
||||
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
||||
tile_scheduler = TileSchedulerType.Default
|
||||
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False
|
||||
|
||||
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
|
||||
|
||||
@ -74,6 +74,7 @@ class GemmOperation:
|
||||
GemmKind.Universal3x,
|
||||
GemmKind.SparseUniversal3x,
|
||||
GemmKind.BlockScaledUniversal3x,
|
||||
GemmKind.GroupedGemmUniversal3x,
|
||||
}
|
||||
self.is_3x = gemm_kind in kinds_3x
|
||||
self.prefix = "3x" if self.is_3x else ""
|
||||
@ -111,6 +112,12 @@ class GemmOperation:
|
||||
self.swizzling_functor = swizzling_functor
|
||||
self.tile_scheduler = tile_scheduler
|
||||
|
||||
# Only enable mixed input mode and mixed input shuffle for Hopper
|
||||
self.mixed_input_mode = None
|
||||
if self.is_mixed_input() and self.arch >= 90 and self.arch < 100:
|
||||
self.mixed_input_mode = mixed_input_mode
|
||||
self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle
|
||||
|
||||
#
|
||||
def is_complex(self):
|
||||
complex_operators = [
|
||||
@ -211,6 +218,18 @@ class GemmOperation:
|
||||
|
||||
return extended_name
|
||||
|
||||
#
|
||||
def mixed_input_mode_name(self):
  '''Builds the name suffix describing the mixed-input configuration.

  The suffix is "_cvt", "_scl" or "_sclzr" for the ConvertOnly, ScaleOnly and
  ScaleWithZeroPoint modes respectively, and empty for any other value of
  self.mixed_input_mode. "_shfl" is appended when self.mixed_input_shuffle is set.
  '''
  suffix_by_mode = {
    MixedInputMode.ConvertOnly: "_cvt",
    MixedInputMode.ScaleOnly: "_scl",
    MixedInputMode.ScaleWithZeroPoint: "_sclzr"
  }
  # Modes missing from the table contribute no mode suffix.
  mode_suffix = suffix_by_mode.get(self.mixed_input_mode, "")
  shuffle_suffix = "_shfl" if self.mixed_input_shuffle else ""
  return mode_suffix + shuffle_suffix
|
||||
|
||||
def extended_name_3x(self):
|
||||
'''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
|
||||
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
@ -237,6 +256,8 @@ class GemmOperation:
|
||||
element_d = d_type_names,
|
||||
core_name = self.core_name())
|
||||
|
||||
if self.mixed_input_mode != None:
|
||||
extended_name = extended_name + self.mixed_input_mode_name()
|
||||
return extended_name
|
||||
|
||||
def datatype_name_3x(self):
|
||||
@ -768,6 +789,8 @@ using ${operation_name}_epilogue =
|
||||
${epilogue_functor}
|
||||
>::CollectiveOp;
|
||||
|
||||
${mixed_dtype_prepare_code}
|
||||
|
||||
using ${operation_name}_mainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
${arch}, ${opcode_class_main},
|
||||
@ -782,7 +805,7 @@ using ${operation_name}_mainloop =
|
||||
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
${problem_shape},
|
||||
${operation_name}_mainloop,
|
||||
${operation_name}_epilogue,
|
||||
${tile_scheduler}>;
|
||||
@ -830,7 +853,18 @@ ${compile_guard_end}
|
||||
return SubstituteTemplate(block_scaled_template, block_scaled_values)
|
||||
|
||||
|
||||
#
|
||||
@staticmethod
def pointerize_if_grouped(operation, layout):
  '''Returns `layout` with "* " appended when the operation is a grouped 3.x GEMM.

  For every gemm_kind other than GemmKind.GroupedGemmUniversal3x the layout
  string is returned unchanged.
  '''
  if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x:
    return layout
  return layout + "* "
|
||||
|
||||
@staticmethod
def problem_shape(operation):
  '''Returns the problem-shape type string to emit for `operation`.

  Operations of kind GemmKind.GroupedGemmUniversal3x get a
  cutlass::gemm::GroupProblemShape wrapped around a three-int cute::Shape;
  every other kind gets a plain four-int cute::Shape.
  '''
  if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x:
    return "cute::Shape<int,int,int,int>"

  per_group_shape = "cute::Shape<int,int,int>"
  return "cutlass::gemm::GroupProblemShape<" + per_group_shape + ">"
|
||||
|
||||
def emit(self, operation):
|
||||
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
|
||||
_LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name())
|
||||
@ -926,17 +960,83 @@ ${compile_guard_end}
|
||||
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
|
||||
|
||||
|
||||
operation_name_str = operation.procedural_name()
|
||||
layout_a_str = LayoutTag[instance_layout_A]
|
||||
layout_b_str = LayoutTag[instance_layout_B]
|
||||
mixed_dtype_prepare_code = ""
|
||||
if operation.mixed_input_mode != None:
|
||||
A_dtype = operation.A.element
|
||||
B_dtype = operation.B.element
|
||||
A_dtype_bits = DataTypeSize[A_dtype]
|
||||
B_dtype_bits = DataTypeSize[B_dtype]
|
||||
is_A_dtype_narrow = A_dtype_bits < B_dtype_bits
|
||||
if is_A_dtype_narrow:
|
||||
narrow_dtype, wide_dtype = (A_dtype, B_dtype)
|
||||
narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
|
||||
else:
|
||||
narrow_dtype, wide_dtype = (B_dtype, A_dtype)
|
||||
narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)
|
||||
|
||||
narrow_tag = DataTypeTag[narrow_dtype]
|
||||
wide_tag = DataTypeTag[wide_dtype]
|
||||
scale_tag = DataTypeTag[wide_dtype]
|
||||
zero_tag = DataTypeTag[wide_dtype]
|
||||
|
||||
do_shuffle = False
|
||||
value_shuffle_str = ""
|
||||
if narrow_dtype_bits == 4 and wide_dtype_bits == 16:
|
||||
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_4>, cute::Stride<cute::_4,cute::_1>>"
|
||||
do_shuffle = True
|
||||
if narrow_dtype_bits == 8 and wide_dtype_bits == 16:
|
||||
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_2>, cute::Stride<cute::_2,cute::_1>>"
|
||||
do_shuffle = True
|
||||
do_shuffle = operation.mixed_input_shuffle and do_shuffle
|
||||
|
||||
if do_shuffle:
|
||||
if is_A_dtype_narrow:
|
||||
stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>"
|
||||
layout_a_str = f"{operation_name_str}_LayoutNarrowReordered"
|
||||
else:
|
||||
stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>"
|
||||
layout_b_str = f"{operation_name_str}_LayoutNarrowReordered"
|
||||
# The {operation_name_str}_ prefixes in mixed_dtype_prepare_code and
|
||||
# layout_{a, b}_str are to prevent errors in Windows platform unity build
|
||||
mixed_dtype_prepare_code = f"""
|
||||
using {operation_name_str}_StrideNarrow = {stride_narrow_str};
|
||||
using {operation_name_str}_ValueShuffle = {value_shuffle_str};
|
||||
static constexpr int {operation_name_str}_NumShuffleAtoms = 1;
|
||||
using {operation_name_str}_MmaAtomShape = cute::Layout<cute::Shape<cute::_1, cute::Int<{operation_name_str}_NumShuffleAtoms>>>;
|
||||
using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>());
|
||||
using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
|
||||
"""
|
||||
|
||||
mixed_input_modes_to_element = {
|
||||
MixedInputMode.ConvertOnly: narrow_tag,
|
||||
MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>",
|
||||
MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>"
|
||||
}
|
||||
narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag)
|
||||
|
||||
if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2):
|
||||
narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>"
|
||||
|
||||
if is_A_dtype_narrow:
|
||||
element_a = narrow_element
|
||||
else:
|
||||
element_b = narrow_element
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_name': operation_name_str,
|
||||
'operation_suffix': self.operation_suffix,
|
||||
'problem_shape': self.problem_shape(operation),
|
||||
'element_a': element_a,
|
||||
'layout_a': LayoutTag[instance_layout_A],
|
||||
'layout_a': self.pointerize_if_grouped(operation, layout_a_str),
|
||||
'element_b': element_b,
|
||||
'layout_b': LayoutTag[instance_layout_B],
|
||||
'layout_b': self.pointerize_if_grouped(operation, layout_b_str),
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': LayoutTag[instance_layout_C],
|
||||
'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]),
|
||||
'element_d': DataTypeTag[operation.D.element],
|
||||
'layout_d': LayoutTag[instance_layout_D],
|
||||
'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]),
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class_main': OpcodeClassTag[opcode_class_main],
|
||||
'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
|
||||
@ -968,6 +1068,7 @@ ${compile_guard_end}
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]),
|
||||
'mixed_dtype_prepare_code': mixed_dtype_prepare_code
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.gemm_template, values)
|
||||
@ -1294,7 +1395,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
|
||||
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
|
||||
GemmKind.Grouped: EmitGemmGroupedInstance
|
||||
GemmKind.Grouped: EmitGemmGroupedInstance,
|
||||
GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance,
|
||||
}
|
||||
|
||||
self.gemm_kind_wrappers = {
|
||||
@ -1306,7 +1408,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation',
|
||||
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
|
||||
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
|
||||
GemmKind.Grouped: 'GemmGroupedOperation'
|
||||
GemmKind.Grouped: 'GemmGroupedOperation',
|
||||
GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation'
|
||||
}
|
||||
|
||||
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
|
||||
@ -1363,6 +1466,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
("library_internal.h", None),
|
||||
("gemm_operation.h", None),
|
||||
("gemm_operation_3x.hpp", None),
|
||||
("grouped_gemm_operation_3x.hpp", None),
|
||||
("sparse_gemm_operation_3x.hpp", None),
|
||||
("block_scaled_gemm_operation_3x.hpp", None),
|
||||
("cutlass/arch/wmma.h", None),
|
||||
|
||||
Reference in New Issue
Block a user