update 3.8 v2 (#2112)
* update 3.8 v2 * update 3.8 --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@ -532,29 +532,12 @@ def tuple_factory_(input_tuple, dtype, constants=[0,1]):
|
||||
if first_non_empty_base is None:
|
||||
first_non_empty_base = []
|
||||
|
||||
# Determine whether or not to add an additional byte for empty base classes
|
||||
additional_byte = False
|
||||
# Special case for constant tuple
|
||||
if first_non_empty_base is None:
|
||||
additional_byte = False
|
||||
else:
|
||||
for base in first_non_empty_base:
|
||||
if base in empty_bases:
|
||||
additional_byte = True
|
||||
break
|
||||
|
||||
if additional_byte:
|
||||
ctype_fields = [("empty_byte", EmptyByte), ] + ctype_fields
|
||||
|
||||
# Create the ctype tuple
|
||||
class TupleType(ctypes.Structure):
|
||||
_fields_ = ctype_fields
|
||||
|
||||
def __init__(self, args) -> None:
|
||||
if additional_byte:
|
||||
fields = self._fields_[1:]
|
||||
else:
|
||||
fields = self._fields_
|
||||
fields = self._fields_
|
||||
|
||||
assert len(fields) == len(args)
|
||||
for field, arg in zip(fields, args):
|
||||
|
||||
@ -69,7 +69,7 @@ using ${operation_name}_epilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
${arch},
|
||||
${opcode_class_epi},
|
||||
${output_cta_tile_shape}, // output cta tile shape
|
||||
${mma_tile_shape}, // mma tile shape
|
||||
${cluster_shape}, // cluster shape
|
||||
${epi_tile_mn},
|
||||
${element_accumulator},
|
||||
@ -109,26 +109,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
||||
def arch_number_to_type(self, arch: int) -> str:
    """Return the CUTLASS architecture tag type name for an SM number.

    Args:
        arch: The SM number to embed in the name, e.g. ``90``.

    Returns:
        The string ``cutlass::arch::Sm<arch>``.
    """
    return f"cutlass::arch::Sm{arch}"
|
||||
|
||||
def output_cta_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
    """Render the ``cute::Shape<...>`` text for a convolution's output CTA tile.

    Args:
        operation: Object whose ``conv_kind`` is compared against
            ``ConvKind.Wgrad`` to decide how the N mode is rendered.
        cta_m: Value substituted for the M extent.
        cta_n: Value substituted for the N extent.
        cta_k: Value substituted for the K extent.

    Returns:
        A string of the form ``cute::Shape<M, N, K>`` with the extents
        rendered as ``cute::_<value>``.
    """
    # For all three kinds of convolutions, the tile shape's K mode
    # differs from GEMM in that it needs to be wrapped in a Shape.
    # For Wgrad convolutions specifically,
    # the N tile shape also needs to be wrapped in a Shape.
    m_template = 'cute::_${cta_m}'
    if operation.conv_kind == ConvKind.Wgrad:
        n_template = 'cute::Shape<cute::_${cta_n}>'
    else:
        n_template = 'cute::_${cta_n}'
    k_template = 'cute::Shape<cute::_${cta_k}>'

    output_cta_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
    values = {
        'cta_m': cta_m,
        'cta_n': cta_n,
        'cta_k': cta_k,
    }
    return Template(output_cta_tile_shape_template).substitute(values)
|
||||
|
||||
def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
|
||||
mma_m = cta_m
|
||||
mma_n = cta_n
|
||||
@ -223,7 +203,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class': opcode_class,
|
||||
'arch': self.arch_number_to_type(operation.arch),
|
||||
'output_cta_tile_shape': self.output_cta_tile_shape(operation, cta_m, cta_n, cta_k),
|
||||
'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
|
||||
'cluster_shape': self.cluster_shape(operation),
|
||||
'opcode_class_epi': opcode_class_epi,
|
||||
|
||||
@ -90,19 +90,32 @@ def hash_cutlass_string(input_string):
|
||||
def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
    """Replace runtime-datatype placeholders in a kernel name.

    Two substitution passes are applied, in this order:

    1. Adjacent A/B placeholder pairs (``f4_f6``, ``ue8m0xf4_ue8m0xf6``, ...)
       are rewritten so the first element becomes ``runtime_datatype_a`` and
       the second becomes ``runtime_datatype_b``.
    2. Any remaining stand-alone ``_f4_`` / ``_f6_`` / ``_f8_`` token is
       rewritten with the single-token mapping.

    The pair pass must run first and separately: a single-token key such as
    ``_f4_`` starts one character earlier than ``f4_f6`` and would otherwise
    consume the shared underscore, leaving the second placeholder untouched.
    Previously the pair-aware result was computed and then discarded by a
    second ``re.sub`` over the original string.

    Args:
        hashed_kernel_name: Kernel name containing placeholder tokens.
        runtime_datatype_a: Replacement text for the A operand placeholder.
        runtime_datatype_b: Replacement text for the B operand placeholder.

    Returns:
        The kernel name with every recognized placeholder replaced; names
        without placeholders are returned unchanged.
    """
    # Mapping from adjacent A/B placeholder pairs to runtime values.
    # NOTE(review): there is no 'ue8m0xf6_ue8m0xf8' entry although the other
    # eight ue8m0 combinations are listed -- confirm whether that is intended.
    pair_map = {
        'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
        'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b,
        'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b,
        'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b,
        'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b,
        'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b,
        'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b,
        'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b,
        'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b,
        'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
        'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b,
        'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
        'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
        'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
        'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
        'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
        'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
        'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    }

    # Fallback mapping for stand-alone "_f4_", "_f6_", or "_f8_" tokens.
    single_map = {
        '_f4_': '_' + runtime_datatype_a + '_',
        '_f6_': '_' + runtime_datatype_b + '_',
        '_f8_': '_' + runtime_datatype_a + '_',
    }

    def replace_keys(mapping, text):
        # Build one alternation over all (escaped) keys and map each match
        # back through the dictionary.
        pattern = re.compile('|'.join(map(re.escape, mapping)))
        return pattern.sub(lambda match: mapping[match.group(0)], text)

    updated_kernel_name = replace_keys(pair_map, hashed_kernel_name)
    updated_kernel_name = replace_keys(single_map, updated_kernel_name)

    return updated_kernel_name
|
||||
|
||||
# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k.
|
||||
|
||||
@ -64,17 +64,15 @@ class GemmOperation:
|
||||
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
||||
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False
|
||||
|
||||
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
|
||||
|
||||
):
|
||||
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False,
|
||||
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None):
|
||||
|
||||
kinds_3x = {
|
||||
GemmKind.Universal3x,
|
||||
GemmKind.SparseUniversal3x,
|
||||
GemmKind.BlockScaledUniversal3x,
|
||||
GemmKind.GroupedGemmUniversal3x,
|
||||
GemmKind.GroupedUniversal3x,
|
||||
GemmKind.GroupedBlockScaledUniversal3x,
|
||||
}
|
||||
self.is_3x = gemm_kind in kinds_3x
|
||||
self.prefix = "3x" if self.is_3x else ""
|
||||
@ -87,13 +85,11 @@ class GemmOperation:
|
||||
self.C = C
|
||||
self.D = D
|
||||
|
||||
|
||||
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
|
||||
if is_block_scaled(gemm_kind):
|
||||
self.ScaleFactorA = ScaleFactorA
|
||||
self.ScaleFactorB = ScaleFactorB
|
||||
self.ScaleFactorD = ScaleFactorD["tensor"]
|
||||
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
|
||||
|
||||
|
||||
if self.D == None:
|
||||
self.D = self.C
|
||||
@ -239,13 +235,13 @@ class GemmOperation:
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = DataTypeNames[self.D.element],
|
||||
core_name = self.core_name())
|
||||
|
||||
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
|
||||
|
||||
if is_block_scaled(self.gemm_kind):
|
||||
d_type_names = DataTypeNames[self.D.element]
|
||||
|
||||
|
||||
if self.ScaleFactorD.element != DataType.void:
|
||||
d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names
|
||||
|
||||
|
||||
extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_sfa = DataTypeNames[self.ScaleFactorA],
|
||||
element_a = DataTypeNames[self.A.element],
|
||||
@ -255,7 +251,7 @@ class GemmOperation:
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = d_type_names,
|
||||
core_name = self.core_name())
|
||||
|
||||
|
||||
if self.mixed_input_mode != None:
|
||||
extended_name = extended_name + self.mixed_input_mode_name()
|
||||
return extended_name
|
||||
@ -298,8 +294,8 @@ class GemmOperation:
|
||||
|
||||
# Generates a short string representing underlying epilogue schedule type
|
||||
def epilogue_schedule_name_3x(self):
|
||||
|
||||
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
|
||||
|
||||
if is_block_scaled(self.gemm_kind):
|
||||
if self.ScaleFactorD.element != DataType.void:
|
||||
return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout]
|
||||
|
||||
@ -779,7 +775,7 @@ class EmitGemmUniversal3xInstance:
|
||||
using ${operation_name}_epilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
${arch}, ${opcode_class_epi},
|
||||
cute::Shape<cute::_${tile_shape_epi_m}, cute::_${tile_shape_epi_n}, cute::_${tile_shape_epi_k}>,
|
||||
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
||||
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
||||
${epi_tile_mn},
|
||||
${element_accumulator}, ${element_epilogue},
|
||||
@ -797,7 +793,7 @@ using ${operation_name}_mainloop =
|
||||
${element_a}, ${layout_a}, ${align_a},
|
||||
${element_b}, ${layout_b}, ${align_b},
|
||||
${element_accumulator},
|
||||
cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
|
||||
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
||||
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
||||
${stages},
|
||||
${kernel_schedule}
|
||||
@ -855,7 +851,7 @@ ${compile_guard_end}
|
||||
|
||||
@staticmethod
def pointerize_if_grouped(operation, layout):
    """Return ``layout``, turned into a pointer type for grouped GEMMs.

    Args:
        operation: Object whose ``gemm_kind`` is passed to ``is_grouped``.
        layout: Layout type text.

    Returns:
        ``layout + "* "`` when ``is_grouped(operation.gemm_kind)`` is true,
        otherwise ``layout`` unchanged.
    """
    # Use the shared is_grouped() predicate rather than comparing against a
    # single GemmKind member, so this stays in step with the other grouped
    # checks in this file.
    if is_grouped(operation.gemm_kind):
        return layout + "* "
    return layout
|
||||
|
||||
@staticmethod
|
||||
def problem_shape(operation):
|
||||
@ -863,7 +859,7 @@ ${compile_guard_end}
|
||||
grouped_gemm_shape_type = "cute::Shape<int,int,int>"
|
||||
grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
|
||||
|
||||
return gemm_shape_type if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else grouped_gemm_shape_type
|
||||
return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type
|
||||
|
||||
def emit(self, operation):
|
||||
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
|
||||
@ -874,18 +870,12 @@ ${compile_guard_end}
|
||||
opcode_class_main = operation.tile_description.math_instruction.opcode_class
|
||||
opcode_class_epi = opcode_class_main
|
||||
|
||||
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
||||
if operation.epilogue_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
|
||||
opcode_class_epi = OpcodeClass.TensorOp
|
||||
|
||||
|
||||
tile_shape = operation.tile_description.tile_shape
|
||||
instruction_shape = operation.tile_description.math_instruction.instruction_shape
|
||||
cluster_m = operation.tile_description.cluster_shape[0]
|
||||
cluster_n = operation.tile_description.cluster_shape[1]
|
||||
|
||||
tile_shape_main_m, tile_shape_main_n, tile_shape_main_k = tile_shape
|
||||
tile_shape_epi_m, tile_shape_epi_n, tile_shape_epi_k = tile_shape
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
|
||||
|
||||
# account for static/dynamic cluster shapes
|
||||
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
|
||||
@ -902,10 +892,8 @@ ${compile_guard_end}
|
||||
if opcode_class_main in [OpcodeClass.TensorOp
|
||||
, OpcodeClass.BlockScaledTensorOp
|
||||
]:
|
||||
tile_shape_main_m = instruction_shape[0]
|
||||
tile_shape_main_n = instruction_shape[1]
|
||||
tile_shape_epi_m = cta_m
|
||||
tile_shape_epi_n = cta_n
|
||||
tile_shape_m = instruction_shape[0]
|
||||
tile_shape_n = instruction_shape[1]
|
||||
|
||||
|
||||
# stage count set to zero indicates builder automatic stage selection
|
||||
@ -930,35 +918,36 @@ ${compile_guard_end}
|
||||
}
|
||||
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
||||
|
||||
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
|
||||
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
|
||||
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
||||
|
||||
|
||||
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
|
||||
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
|
||||
|
||||
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
|
||||
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
||||
|
||||
|
||||
#
|
||||
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
|
||||
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
|
||||
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
|
||||
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100:
|
||||
is_no_smem_epilogue = operation.epilogue_schedule in [EpilogueScheduleType.NoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
grouped = is_grouped(operation.gemm_kind)
|
||||
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped):
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
|
||||
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped):
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
|
||||
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
|
||||
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
|
||||
|
||||
|
||||
|
||||
operation_name_str = operation.procedural_name()
|
||||
layout_a_str = LayoutTag[instance_layout_A]
|
||||
@ -1041,12 +1030,9 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape(
|
||||
'opcode_class_main': OpcodeClassTag[opcode_class_main],
|
||||
'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
|
||||
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
||||
'tile_shape_epi_m': str(tile_shape_epi_m),
|
||||
'tile_shape_epi_n': str(tile_shape_epi_n),
|
||||
'tile_shape_epi_k': str(tile_shape_epi_k),
|
||||
'tile_shape_main_m': str(tile_shape_main_m),
|
||||
'tile_shape_main_n': str(tile_shape_main_n),
|
||||
'tile_shape_main_k': str(tile_shape_main_k),
|
||||
'tile_shape_m': str(tile_shape_m),
|
||||
'tile_shape_n': str(tile_shape_n),
|
||||
'tile_shape_k': str(tile_shape_k),
|
||||
'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int",
|
||||
'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int",
|
||||
'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int",
|
||||
@ -1396,7 +1382,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
|
||||
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
|
||||
GemmKind.Grouped: EmitGemmGroupedInstance,
|
||||
GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance,
|
||||
}
|
||||
|
||||
self.gemm_kind_wrappers = {
|
||||
@ -1409,7 +1396,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
|
||||
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
|
||||
GemmKind.Grouped: 'GemmGroupedOperation',
|
||||
GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation'
|
||||
GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation',
|
||||
GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation',
|
||||
}
|
||||
|
||||
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
|
||||
|
||||
@ -217,8 +217,7 @@ def CreateGemmUniversal3xOperator(
|
||||
gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"]
|
||||
gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]),
|
||||
"vector_size" : data_type["sfd_type"]["vector_size"]}
|
||||
gemm_kind = GemmKind.BlockScaledUniversal3x
|
||||
|
||||
assert is_block_scaled(gemm_kind)
|
||||
|
||||
A_dtype = data_type["a_type"]
|
||||
B_dtype = data_type["b_type"]
|
||||
@ -254,9 +253,6 @@ def CreateGemmUniversal3xOperator(
|
||||
|
||||
return operations
|
||||
|
||||
def is_grouped(gemm_kind):
    """Return True when ``gemm_kind`` is one of the grouped 3.x GEMM kinds.

    The recognized kinds are the grouped members that this code base lists
    among its 3.x kinds: ``GroupedGemmUniversal3x``, ``GroupedUniversal3x``
    and ``GroupedBlockScaledUniversal3x``.

    Args:
        gemm_kind: A ``GemmKind`` member.

    Returns:
        ``True`` for a grouped kind, ``False`` otherwise.
    """
    return gemm_kind in (
        GemmKind.GroupedGemmUniversal3x,
        GemmKind.GroupedUniversal3x,
        GemmKind.GroupedBlockScaledUniversal3x,
    )
|
||||
|
||||
# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts
|
||||
def CreateSparseGemmUniversal3xOperator(
|
||||
manifest, layouts, tile_descriptions, data_types,
|
||||
@ -6654,11 +6650,13 @@ def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int
|
||||
|
||||
sm100_cluster_shape_1sm = [
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
sm100_cluster_shape_2sm = [
|
||||
# cluster_m % 2 == 0 for 2sm
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
@ -6718,6 +6716,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
]
|
||||
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
@ -6765,6 +6764,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
]
|
||||
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
@ -7517,8 +7517,227 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||
|
||||
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
    """Create SM100 UMMA GEMM operators with mixed F4/F6/F8 inputs, no block scale.

    For every enabled instruction size and every A/B type combination in
    which both operands are runtime types (f4/f6/f8) or both are static types
    (e2m1/e3m2/e4m3), tile descriptions are built per cluster shape and
    ``CreateGemmUniversal3xOperator`` is called once per C/D data-type
    configuration, first for the 1SM schedules and then for the 2SM schedules.

    Args:
        manifest: Passed through to ``CreateGemmUniversal3xOperator``.
        cuda_version: Checked with ``CudaToolkitVersionSatisfies(cuda_version, 12, 0)``;
            nothing is generated when the check fails.

    Returns:
        None.
    """
    # SM100 MMA with mixed F4/F6/F8 inputs + without block scale
    if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
        return

    # layouts for ABC and their alignments.
    # The -1 alignments are placeholders overwritten per data-type configuration below.
    layouts = [
        [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
    ]

    instruction_sizes_1sm = [
        # [64, 128, 32],
        [128, 128, 32],
        # [64, 256, 32],
        [128, 256, 32],
    ]

    instruction_sizes_2sm = [
        # [128, 128, 32],
        # [128, 256, 32],
        [256, 128, 32],
        [256, 256, 32],
    ]

    ab_types = [
        DataType.f4, DataType.f6, DataType.f8,
        DataType.e2m1, DataType.e3m2, DataType.e4m3,
    ]

    acc_types = [DataType.f32]

    tile_schedulers = [
        TileSchedulerType.Default, TileSchedulerType.StreamK
    ]

    min_cc = 100
    max_cc = 130

    epi_type = DataType.f32

    def is_runtime_datatype(runtime_datatype):
        return runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)

    def build_math_instructions(instruction_sizes):
        # One MathInstruction per (size, A, B, accumulator) combination.
        math_instructions = []
        for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types):
            # A/B datatypes should be both static or dynamic
            if is_runtime_datatype(a_type) != is_runtime_datatype(b_type):
                continue

            math_instructions.append(
                MathInstruction(
                    instr_size,
                    a_type, b_type, acc_type,
                    OpcodeClass.TensorOp,
                    MathOperation.multiply_add)
            )
        return math_instructions

    def build_tile_descriptions(math_inst, cluster_shapes, cluster_m_divisor):
        # Scale the instruction shape by the cluster shape; for 2SM kernels the
        # cluster M extent is halved first (cluster_m_divisor == 2).
        tile_descriptions = []
        for cluster_shape in cluster_shapes:
            if cluster_shape == DynamicClusterShape:
                multiplier = (1, 1, 1)
            else:
                multiplier = (cluster_shape[0] // cluster_m_divisor, cluster_shape[1], cluster_shape[2])
            tile_descriptions.append(
                TileDescription([
                    math_inst.instruction_shape[0] * multiplier[0],
                    math_inst.instruction_shape[1] * multiplier[1],
                    math_inst.instruction_shape[2] * 4 * multiplier[2]],
                    0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
        return tile_descriptions

    def build_kernel_data_types(math_inst):
        # The three C/D configurations generated for every math instruction.
        return [
            {
                "a_type": math_inst.element_a,
                "b_type": math_inst.element_b,
                "c_type": c_type,
                "d_type": d_type,
                "acc_type": math_inst.element_accumulator,
                "epi_type": epi_type,
            }
            for c_type, d_type in (
                (DataType.f32, DataType.f32),
                (DataType.void, DataType.f32),
                (DataType.void, DataType.e5m2),
            )
        ]

    def create_operators(math_inst, tile_descriptions, kernel_schedule, epilogue_schedule):
        for kernel_data_type in build_kernel_data_types(math_inst):
            # Filter out some kernels
            if (kernel_data_type["a_type"] == DataType.e4m3
                    and kernel_data_type["b_type"] == DataType.e4m3
                    and kernel_data_type["d_type"] == DataType.e5m2):
                continue

            # Update layout alignment
            # alignment for d might be different for each kernel_data_type
            layouts_copy = copy.deepcopy(layouts)
            for layout in layouts_copy:
                # alignment for a
                layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"])
                # alignment for b
                layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"])
                # alignment for d
                layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"])

            CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
                [[kernel_schedule, epilogue_schedule]], tile_schedulers=tile_schedulers)

    math_instructions_1sm = build_math_instructions(instruction_sizes_1sm)
    math_instructions_2sm = build_math_instructions(instruction_sizes_2sm)

    cluster_shapes_1sm = [
        # [1,2,1],
        [2,1,1],
        [1,1,1],
        # [1,4,1],
        [4,4,1]
        , DynamicClusterShape
    ]

    # 1xSM MMA kernels
    for math_inst in math_instructions_1sm:
        tile_descriptions = build_tile_descriptions(math_inst, cluster_shapes_1sm, 1)
        create_operators(math_inst, tile_descriptions,
                         KernelScheduleType.TmaWarpSpecialized1SmSm100,
                         EpilogueScheduleType.TmaWarpSpecialized1Sm)

    cluster_shapes_2sm = [
        [2,1,1],
        # [2,2,1],
        # [2,4,1],
        # [4,1,1],
        # [4,2,1],
        [4,4,1]
        , DynamicClusterShape
    ]

    # 2xSM MMA kernels
    for math_inst in math_instructions_2sm:
        tile_descriptions = build_tile_descriptions(math_inst, cluster_shapes_2sm, 2)
        # With the instruction sizes currently enabled above (M == 256) only the
        # ScheduleAuto branch is taken; the M == 128 case applies if the
        # commented-out sizes are re-enabled.
        if math_inst.instruction_shape[0] == 128:
            epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
        else:
            epilogue_schedule = EpilogueScheduleType.ScheduleAuto
        create_operators(math_inst, tile_descriptions,
                         KernelScheduleType.TmaWarpSpecialized2SmSm100,
                         epilogue_schedule)
|
||||
|
||||
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM100 MMA with mixed F4/F6/F8 inputs + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
@ -7529,7 +7748,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
]
|
||||
|
||||
instruction_sizes_1sm = [
|
||||
[128, 128, 32], [128, 256, 32], # Mixed F4/F6/F8 block scaled only supports M=128 for 1SM cases
|
||||
[128, 128, 32], [128, 256, 32], # Block scaled kernels only support M=128 for 1SM cases
|
||||
]
|
||||
|
||||
instruction_sizes_2sm = [
|
||||
@ -7670,8 +7889,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
for data_type in data_types:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
@ -7766,21 +7984,21 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
if math_inst.instruction_shape[0] == 128:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
else:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
|
||||
|
||||
|
||||
def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM100 MMA with F4 + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
|
||||
@ -7805,7 +8023,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
def tile_schedulers(sfdtype):
    """Choose the tile schedulers to generate for a scale-factor-D config.

    Args:
        sfdtype: Mapping whose ``"type"`` entry is compared to ``DataType.void``.

    Returns:
        ``[TileSchedulerType.Default]`` when SFD is void or the enclosing
        generator's ``grouped`` flag is set; otherwise the default and
        stream-K schedulers.
    """
    # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
    # the epilogue is the traditional linear combination, for which we already have tests with stream-K.
    # Grouped kernels are likewise restricted to the default scheduler.
    if sfdtype["type"] == DataType.void or grouped:
        return [TileSchedulerType.Default]
    return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
@ -7826,6 +8044,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
# grouped GEMM does not support runtime data type yet
|
||||
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_1sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
@ -7853,6 +8075,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
# grouped GEMM does not support runtime data type yet
|
||||
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_2sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
@ -7972,15 +8198,21 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
for data_type in data_types:
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
# E2M1 x E2M1, vector size 16, UE4M3
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
|
||||
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped)
|
||||
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped)
|
||||
|
||||
nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule]
|
||||
fp4_schedule = [fp4_kernel_schedule, epi_schedule]
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule]
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
|
||||
)
|
||||
if isFp4:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule]
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
|
||||
)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
@ -8085,18 +8317,20 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
for data_type in data_types:
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
|
||||
nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]
|
||||
fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]
|
||||
epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm
|
||||
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped)
|
||||
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped)
|
||||
|
||||
nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule]
|
||||
fp4_schedule = [fp4_kernel_schedule, epi_schedule]
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule]
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
if isFp4:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule]
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
|
||||
|
||||
|
||||
@ -8139,6 +8373,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
MathOperation.multiply_add)]
|
||||
|
||||
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
@ -8237,6 +8472,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
]
|
||||
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
@ -8353,6 +8589,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [1,1,1], [1,4,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
@ -8386,6 +8623,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
@ -8431,6 +8669,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [1,1,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
@ -8498,6 +8737,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
@ -8554,6 +8794,125 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
|
||||
|
||||
|
||||
def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
  """Generate SM100 FP8 UMMA GEMM operations that use the stream-K tile scheduler.

  Builds one set of tile descriptions for the 1-SM math instruction and one
  for the 2-SM math instruction (e4m3 x e4m3 with f32 accumulation), then
  hands each set to ``CreateGemmUniversal3xOperator`` together with
  ``TileSchedulerType.StreamK`` as the only tile scheduler.

  Args:
    manifest: forwarded unchanged to ``CreateGemmUniversal3xOperator``.
    cuda_version: forwarded to ``CudaToolkitVersionSatisfies``; nothing is
      generated when that check fails.

  Returns:
    None.
  """
  # NOTE(review): the fp4 block-scaled SM100 generator in this file gates on
  # toolkit 12.8 while this one gates on 12.0 -- confirm 12.0 is intended.
  if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
    return

  # layouts for ABC and their alignments.
  # The C/D alignment starts at 0 and is overwritten below from the D type.
  layouts = [
    [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
    [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
    [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
    [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
    [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
  ]

  min_cc = 100
  max_cc = 130

  epi_type = DataType.f32

  math_instructions_1sm = [
    MathInstruction(
      [128, 256, 32],
      DataType.e4m3, DataType.e4m3, DataType.f32,
      OpcodeClass.TensorOp,
      MathOperation.multiply_add)]

  cluster_shapes_1sm = [
    [1,2,1], [2,1,1], [1,1,1], [4,4,1]
    , DynamicClusterShape
  ]

  # Shared by the 1-SM and 2-SM sections: stream-K only.
  tile_schedulers = [
    TileSchedulerType.StreamK,
  ]

  # 1xSM MMA kernels
  for math_inst in math_instructions_1sm:
    tile_descriptions = []
    for cluster_shape in cluster_shapes_1sm:
      # A dynamic cluster shape has no static extents, so the tile is not scaled.
      multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
      tile_descriptions.append(
        TileDescription([
            math_inst.instruction_shape[0] * multiplier_1sm[0],
            math_inst.instruction_shape[1] * multiplier_1sm[1],
            math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
            0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))

    data_types = [
      {
        "a_type"   : math_inst.element_a,
        "b_type"   : math_inst.element_b,
        "c_type"   : DataType.f16,
        "d_type"   : DataType.e4m3,
        "acc_type" : math_inst.element_accumulator,
        "epi_type" : epi_type,
      }]

    # Set alignment d based on Destination format.
    for layout in layouts:
      layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]

    for data_type in data_types:
      # Skip the e4m3 x e4m3 -> e5m2 combination. With the single entry in
      # data_types above (d_type is e4m3) this filter never fires.
      if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
         ( data_type["d_type"] == DataType.e5m2 ):
        continue
      CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
        [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]],
        tile_schedulers=tile_schedulers)

  # 2xSM MMA kernels
  math_instructions_2sm = [
    MathInstruction(
      [256, 256, 32],
      DataType.e4m3, DataType.e4m3, DataType.f32,
      OpcodeClass.TensorOp,
      MathOperation.multiply_add),
  ]

  cluster_shapes_2sm = [
    [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
    , DynamicClusterShape
  ]

  for math_inst in math_instructions_2sm:
    tile_descriptions = []
    for cluster_shape in cluster_shapes_2sm:
      # The first cluster extent is halved before scaling the tile --
      # presumably because a 2-SM instruction already spans two CTAs; confirm.
      multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
      tile_descriptions.append(
        TileDescription([
            math_inst.instruction_shape[0] * multiplier_2sm[0],
            math_inst.instruction_shape[1] * multiplier_2sm[1],
            math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]],
            0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))

    data_types = [
      {
        "a_type"   : math_inst.element_a,
        "b_type"   : math_inst.element_b,
        "c_type"   : DataType.f16,
        "d_type"   : DataType.e4m3,
        "acc_type" : math_inst.element_accumulator,
        "epi_type" : epi_type,
      }]

    # Set alignment d based on Destination format.
    for layout in layouts:
      layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]

    for data_type in data_types:
      if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
         ( data_type["d_type"] == DataType.e5m2 ):
        continue

      # Pick the epilogue schedule from the instruction's first extent. The
      # only 2-SM instruction listed above has 256 there, so ScheduleAuto is
      # what is selected with the current table.
      if math_inst.instruction_shape[0] == 128:
        epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
      else:
        epi_schedule = EpilogueScheduleType.ScheduleAuto

      CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
        [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
|
||||
|
||||
def GenerateSM100(manifest, cuda_version):
|
||||
#
|
||||
@ -8570,13 +8929,19 @@ def GenerateSM100(manifest, cuda_version):
|
||||
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
|
||||
# grouped GEMM
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
|
||||
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
|
||||
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
|
||||
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version)
|
||||
|
||||
# StreamK is included in regular generation
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
|
||||
#
|
||||
# Block Scaled Gemm
|
||||
#
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
@ -8955,8 +9320,8 @@ def GenerateSM90(manifest, cuda_version):
|
||||
GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
|
||||
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
|
||||
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
|
||||
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
|
||||
GenerateSM90_TensorOp_1684_complex(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version)
|
||||
|
||||
@ -321,6 +321,12 @@ def is_complex(data_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_block_scaled(gemm_kind):
  """Return True when ``gemm_kind`` is a block-scaled 3.x GEMM kind.

  Matches ``GemmKind.BlockScaledUniversal3x`` and
  ``GemmKind.GroupedBlockScaledUniversal3x``; any other value yields False.
  """
  return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
|
||||
|
||||
def is_grouped(gemm_kind):
  """Return True when ``gemm_kind`` is a grouped 3.x GEMM kind.

  Matches ``GemmKind.GroupedUniversal3x`` and
  ``GemmKind.GroupedBlockScaledUniversal3x``; any other value yields False.
  """
  return gemm_kind in (GemmKind.GroupedUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
|
||||
|
||||
#
|
||||
def get_complex_from_real(real_type):
|
||||
for r, c in RealComplexBijection:
|
||||
@ -482,23 +488,32 @@ class KernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
||||
ImplicitTmaWarpSpecializedSm90 = enum_auto()
|
||||
|
||||
|
||||
TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto()
|
||||
PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto()
|
||||
PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
|
||||
Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
||||
@ -519,7 +534,7 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
|
||||
|
||||
|
||||
KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
|
||||
KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
|
||||
|
||||
@ -530,16 +545,25 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
|
||||
|
||||
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
|
||||
|
||||
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
|
||||
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100",
|
||||
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100",
|
||||
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100",
|
||||
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100",
|
||||
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100",
|
||||
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100",
|
||||
}
|
||||
|
||||
#
|
||||
@ -568,16 +592,25 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
|
||||
|
||||
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
||||
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
|
||||
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
||||
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
|
||||
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
||||
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
||||
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
||||
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
||||
}
|
||||
|
||||
class EpilogueScheduleType(enum.Enum):
|
||||
@ -585,6 +618,10 @@ class EpilogueScheduleType(enum.Enum):
|
||||
EpilogueTransposed = enum_auto()
|
||||
NoSmemWarpSpecialized = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized = enum_auto()
|
||||
NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
TmaWarpSpecialized1Sm = enum_auto()
|
||||
@ -600,6 +637,10 @@ EpilogueScheduleTag = {
|
||||
EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
|
||||
@ -616,6 +657,10 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.EpilogueTransposed: '',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
|
||||
@ -636,6 +681,23 @@ EpilogueFunctor3xTag = {
|
||||
EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
|
||||
}
|
||||
|
||||
def to_grouped_schedule(schedule, grouped):
  """Translate a schedule into its grouped (``PtrArray*``) counterpart.

  Args:
    schedule: a ``KernelScheduleType`` or ``EpilogueScheduleType`` member.
    grouped: when falsy, ``schedule`` is returned untouched and the lookup
      table is never consulted.

  Returns:
    ``schedule`` itself for the non-grouped case, otherwise the ``PtrArray*``
    member registered for it in the table below.

  Raises:
    KeyError: ``grouped`` is truthy and ``schedule`` has no entry in the
      table (for example a schedule that is not listed as a key below).
  """
  if not grouped:
    return schedule

  # Kernel and epilogue schedules share one table; the two enums are distinct
  # types, so their members cannot collide as keys.
  group_schedule_map = {
    KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
    KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
    KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
    KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100,
    KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100,
    KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
    EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
    EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
  }

  return group_schedule_map[schedule]
|
||||
|
||||
class TileSchedulerType(enum.Enum):
|
||||
Default = enum_auto()
|
||||
Persistent = enum_auto()
|
||||
@ -817,7 +879,8 @@ class GemmKind(enum.Enum):
|
||||
PlanarComplexArray = enum_auto()
|
||||
Grouped = enum_auto()
|
||||
BlockScaledUniversal3x = enum_auto()
|
||||
GroupedGemmUniversal3x = enum_auto()
|
||||
GroupedUniversal3x = enum_auto()
|
||||
GroupedBlockScaledUniversal3x = enum_auto()
|
||||
|
||||
#
|
||||
GemmKindNames = {
|
||||
@ -830,7 +893,8 @@ GemmKindNames = {
|
||||
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
|
||||
GemmKind.Grouped: "gemm_grouped",
|
||||
GemmKind.BlockScaledUniversal3x: "gemm_block_scaled",
|
||||
GemmKind.GroupedGemmUniversal3x: "gemm_grouped",
|
||||
GemmKind.GroupedUniversal3x: "gemm_grouped",
|
||||
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped_block_scaled"
|
||||
}
|
||||
|
||||
#
|
||||
|
||||
@ -489,7 +489,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0):
|
||||
return [], []
|
||||
|
||||
grouped = gemm_kind == GemmKind.GroupedGemmUniversal3x
|
||||
grouped = is_grouped(gemm_kind)
|
||||
if grouped:
|
||||
# the following cases are unsupported by grouped GEMM
|
||||
if not is_aligned:
|
||||
|
||||
Reference in New Issue
Block a user