CUTLASS 3.5.1 (#1623)

* CUTLASS 3.5.1

* updates, optimizations, fixes
This commit is contained in:
Vijay Thakkar
2024-07-29 08:46:24 -04:00
committed by GitHub
parent 56b46e2d13
commit be60a0b272
312 changed files with 19793 additions and 6775 deletions

View File

@ -121,7 +121,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '3.5.0'
this.__version__ = '3.5.1'
from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch

View File

@ -154,20 +154,6 @@ using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
class Sm90RowBroadcastImpl(RowBroadcastImpl):
@property
def descriptor(self) -> str:
"""
Descriptor for Aux Load
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::RowBroadcastDescriptor<EpilogueDescriptor, {DataTypeTag[self.element]}>;\n"
@property
def type_decl(self):
"""
@ -176,22 +162,14 @@ class Sm90RowBroadcastImpl(RowBroadcastImpl):
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
{self.descriptor}::Stages, typename EpilogueDescriptor::TileShape,
typename {self.descriptor}::Element, {self.stride_mnl}
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
return self._type_decl
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
stages = (stages_c + epi_tiles - 1) // epi_tiles + 1
return (DataTypeSize[self.element] * cta_tile_mnk[1] * stages // 8, 16)
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):

View File

@ -150,7 +150,7 @@ class Conv2dOperation:
else:
group_conv_name = ""
if self.stride_support == StrideSupport.Unity:
if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad:
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
else:
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"

View File

@ -69,8 +69,8 @@ using ${operation_name}_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
${arch},
${opcode_class_epi},
${tile_shape}, // tile shape
${cluster_shape}, // cluster shape
${output_cta_tile_shape}, // output cta tile shape
${cluster_shape}, // cluster shape
${epi_tile_mn},
${element_accumulator},
${element_compute},
@ -88,8 +88,8 @@ using ${operation_name}_mainloop =
${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>,
${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>,
${element_accumulator},
${tile_shape}, // tile shape
${cluster_shape}, // cluster shape
${mma_tile_shape}, // mma tile shape
${cluster_shape}, // cluster shape
${stages},
${kernel_schedule}
>::CollectiveOp;
@ -106,30 +106,54 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
def arch_number_to_type(self, arch: int) -> str:
return f"cutlass::arch::Sm{arch}"
def tile_shape(self, operation) -> str:
def output_cta_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
# For all three kinds of convolutions, the tile shape's K mode
# differs from GEMM in that needs to be wrapped in a Shape.
# For Wgrad convolutions specifically,
# the N tile shape also needs to be wrapped in a Shape.
m_template = 'cute::_${tile_shape_m}'
m_template = 'cute::_${cta_m}'
if operation.conv_kind == ConvKind.Wgrad:
n_template = 'cute::Shape<cute::_${tile_shape_n}>'
n_template = 'cute::Shape<cute::_${cta_n}>'
else:
n_template = 'cute::_${tile_shape_n}'
k_template = 'cute::Shape<cute::_${tile_shape_k}>'
n_template = 'cute::_${cta_n}'
k_template = 'cute::Shape<cute::_${cta_k}>'
tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
output_cta_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
values = {
'tile_shape_m': operation.tile_description.tile_shape[0],
'tile_shape_n': operation.tile_description.tile_shape[1],
'tile_shape_k': operation.tile_description.tile_shape[2]
'cta_m': cta_m,
'cta_n': cta_n,
'cta_k': cta_k
}
return Template(tile_shape_template).substitute(values)
return Template(output_cta_tile_shape_template).substitute(values)
def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
mma_m = cta_m
mma_n = cta_n
mma_k = cta_k
# For all three kinds of convolutions, the tile shape's K mode
# differs from GEMM in that needs to be wrapped in a Shape.
# For Wgrad convolutions specifically,
# the N tile shape also needs to be wrapped in a Shape.
m_template = 'cute::_${mma_m}'
if operation.conv_kind == ConvKind.Wgrad:
n_template = 'cute::Shape<cute::_${mma_n}>'
else:
n_template = 'cute::_${mma_n}'
k_template = 'cute::Shape<cute::_${mma_k}>'
mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
values = {
'mma_m': mma_m,
'mma_n': mma_n,
'mma_k': mma_k
}
return Template(mma_tile_shape_template).substitute(values)
def cluster_shape(self, operation) -> str:
m_template = 'cute::_${cluster_shape_m}'
n_template = 'cute::_${cluster_shape_n}'
k_template = 'cute::_${cluster_shape_k}'
m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)'
n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)'
k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)'
cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
values = {
'cluster_shape_m': operation.tile_description.cluster_shape[0],
@ -159,6 +183,10 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
opcode_class_epi = opcode_class_main
tile_shape = operation.tile_description.tile_shape
cluster_m = operation.tile_description.cluster_shape[0]
cluster_n = operation.tile_description.cluster_shape[1]
cta_m, cta_n, cta_k = tile_shape
warp_count = operation.tile_description.warp_count
epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule]
@ -189,19 +217,20 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
'element_d': DataTypeTag[operation.D.element],
'layout_d': LayoutTag[operation.D.layout],
'align_d': int(operation.D.alignment),
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class': opcode_class,
'arch': self.arch_number_to_type(operation.arch),
'tile_shape': self.tile_shape(operation),
'cluster_shape': self.cluster_shape(operation),
'opcode_class_epi': opcode_class_epi,
'opcode_class_main': opcode_class_main,
'epi_tile_mn': epi_tile_mn,
'stages': self.stage_count(operation),
'kernel_schedule': kernel_schedule,
'epilogue_schedule': epilogue_schedule,
'tile_scheduler': tile_scheduler,
'element_compute': DataTypeTag[operation.element_compute]
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class': opcode_class,
'arch': self.arch_number_to_type(operation.arch),
'output_cta_tile_shape': self.output_cta_tile_shape(operation, cta_m, cta_n, cta_k),
'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
'cluster_shape': self.cluster_shape(operation),
'opcode_class_epi': opcode_class_epi,
'opcode_class_main': opcode_class_main,
'epi_tile_mn': epi_tile_mn,
'stages': self.stage_count(operation),
'kernel_schedule': kernel_schedule,
'epilogue_schedule': epilogue_schedule,
'tile_scheduler': tile_scheduler,
'element_compute': DataTypeTag[operation.element_compute]
}
return Template(self.template).substitute(values)

View File

@ -178,16 +178,28 @@ class GemmOperation:
if self.is_complex():
extended_name = "${core_name}"
else:
# e.g. f16_f16_f32_void_f32 kernel
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_${core_name}_${element_a}"
if self.is_mixed_input():
extended_name += "_${element_b}"
# e.g. f32_f32_f32_void_f32 kernel
elif self.C.element != self.tile_description.math_instruction.element_accumulator and \
self.A.element == self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_${core_name}"
if self.is_mixed_input():
extended_name += "_${element_b}"
# e.g. f16_f16_f32_f32_f32 kernel
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${core_name}_${element_a}"
if self.is_mixed_input():
extended_name += "_${element_b}"
# e.g. f32_f32_f32_f32_f32 kernel
else:
extended_name = "${core_name}"

View File

@ -36,13 +36,13 @@ Utilities for enumerating CUTLASS library kernels
import argparse
import enum
from itertools import product
from itertools import chain, product
import logging
import os.path
import shutil
import sys
import copy
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
_LOGGER = logging.getLogger(__name__)
@ -513,7 +513,7 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
new_operations = [
# None grouped kernel
Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_),
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_),
]
# Instance group conv kernel
@ -521,12 +521,12 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
tile.minimum_compute_capability >= 80:
# SingleGroup kernel
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
# Analytic iterator supports MultipleGroup mode
if iterator_algorithm == IteratorAlgorithm.Analytic:
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup))
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup))
for new_operation in new_operations:
manifest.append(new_operation)
@ -884,7 +884,7 @@ class ConvOperation3x:
prefix = ''
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
prefix = 'g'
return prefix + ShortDataTypeNames[self.accumulator_type()]
return prefix + DataTypeNames[self.accumulator_type()]
def is_tensor_op(self):
tensor_ops = [
@ -1054,8 +1054,11 @@ def CreateConvOperator3x(manifest: Manifest,
log_debug_line(f'conv_kind: {conv_kind}', log_indent_level)
for triple in dims_and_alignments:
spatial_dimensionality = None # to be determined by loop below
assert(isinstance(triple, tuple) or isinstance(triple, list))
assert(len(triple) == 3)
spatial_dimensionality = None # to be determined by loop below
for entry in triple: # [A, B, C]
assert(len(entry) == 2)
[dim, alignment] = entry
@ -6631,85 +6634,352 @@ def GenerateSM90_Conv3x(manifest, cuda_version,
minimum_compute_capability = 90
maximum_compute_capability = 90
spatial_dims = [2, 3]
spatial_dims = (2, 3)
def make_dims_and_alignments_triple(dim: int):
byte_alignment_required_by_tma = 16
return ((dim, byte_alignment_required_by_tma), # A
(dim, byte_alignment_required_by_tma), # B
(dim, byte_alignment_required_by_tma)) # C
dims_and_alignments = [make_dims_and_alignments_triple(dim) for dim in spatial_dims]
# This function only generates kernels that use TMA.
byte_alignment_required_by_tma = 16
tma_byte_alignments = {
'A': byte_alignment_required_by_tma,
'B': byte_alignment_required_by_tma,
'C': byte_alignment_required_by_tma,
}
def make_math_instruction(data_types: Tuple[DataType, DataType, DataType],
instruction_shape: Tuple[int, int, int]) -> MathInstruction:
# For tuples of one element, the element needs to end with comma.
all_byte_alignments = (
tma_byte_alignments,
)
# MMA shapes (MMA_M, MMA_N, MMA_K):
#
# Different hardware MMA instructions may have different MMA shapes.
# This function may generate kernels with different MMA shapes for
# different data types, either because the hardware only supports
# certain shapes for certain types, or for performance reasons
# (CUTLASS doesn't need to generate all valid kernels for the
# profiler library, just the best-performing ones).
#
# The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K)
# instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K,
# where 4, the "number of MMA instructions per tile," is determined
# through some combination of modeling and experiment.
#
# For performance on sm90, generally CUTLASS generates 64x128
# instead of 128x64.
mma_64x64x16 = ( 64, 64, 16)
mma_64x64x8 = ( 64, 64, 8)
num_mma_per_tile = 4
# Cluster shapes (1, 1, 1) and (2, 2, 1) are valid,
# but not included, because they tend not to perform as well.
cluster_shapes = (
(2, 1, 1),
(1, 2, 1),
)
fp16 = DataType.f16
bf16 = DataType.bf16
fp32 = DataType.f32
s8 = DataType.s8
s32 = DataType.s32
# When generating kernels, the usual way is to specify 4 types,
# (A, B, Acc, C/D). Tests instead have 5 types,
# (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute),
# where ElementCompute is also called 'epi_type',
# and corresponds to the type of epilogue activations.
# This script maps tests' 5 types to 4 types
# by making ElementCompute the same as ElementOut.
fp16_fp32_fp16_fp32 = {
'a_type': fp16, # ElementAct(ivation)
'b_type': fp16, # ElementF(i)lt(er)
'c_type': fp32, # ElementAcc
'd_type': fp32, # ElementOut (used only by CollectiveEpilogue)
'acc_type': fp16, # ElementAcc
'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue)
}
fp16_fp32_fp32_fp32 = {
'a_type': fp16,
'b_type': fp16,
'c_type': fp32,
'd_type': fp32,
'acc_type': fp32,
'epi_type': fp32,
}
fp32_fp32_fp32_fp32 = {
'a_type': fp32,
'b_type': fp32,
'c_type': fp32,
'd_type': fp32,
'acc_type': fp32,
'epi_type': fp32,
}
s8_s32_s32_s32 = {
'a_type': s8,
'b_type': s8,
'c_type': s32,
'd_type': s32,
'acc_type': s32,
'epi_type': s32,
}
# Other NVIDIA libraries may have the habit of specifying data types like this.
bf16bf16_bf16f32_f32 = {
'a_type': bf16,
'b_type': bf16,
'c_type': fp32,
'd_type': fp32,
'acc_type': fp32,
'epi_type': fp32,
}
f16f16_f16f16_f16 = {
'a_type': fp16,
'b_type': fp16,
'c_type': fp16,
'd_type': fp16,
'acc_type': fp16,
'epi_type': fp16,
}
f16f16_f16f32_f32 = {
'a_type': fp16,
'b_type': fp16,
'c_type': fp16,
'd_type': fp16,
'acc_type': fp32,
'epi_type': fp32,
}
f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32
i8i8_i8i32_f32 = {
'a_type': s8,
'b_type': s8,
'c_type': s32,
'd_type': s32,
'acc_type': s32,
'epi_type': s32,
}
# Each element in the outermost iterable is one combination of
#
# (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes)
#
# for which to generate a kernel. spatial_dimension is the spatial
# dimension of the convolution: either 1, 2, or 3. byte_alignments
# is a triple of required minimum byte alignments for A, B, and C.
#
# Note that itertools functions produce a single-pass generator.
# The code doesn't need a multipass iterable, but if one did, one
# could call `tuple` or `list` on the generator.
#
# While this happens to use the same cluster sizes for each element,
# the code doesn't require that. Different convolution kinds, data
# types, or mma sizes might have different optimal cluster sizes.
combinations_of_parameters = chain(
# The following are all the kernels exercised in the unit tests.
# Please try to keep in sync with the unit tests.
product(
(
ConvKind.Fprop,
),
spatial_dims,
(
fp16_fp32_fp16_fp32,
fp16_fp32_fp32_fp32,
s8_s32_s32_s32,
),
all_byte_alignments,
(
mma_64x64x16,
),
cluster_shapes
),
product(
(
ConvKind.Fprop,
),
spatial_dims,
(
fp32_fp32_fp32_fp32,
),
all_byte_alignments,
(
mma_64x64x8,
),
cluster_shapes
),
product(
(
ConvKind.Dgrad,
),
spatial_dims,
(
fp16_fp32_fp16_fp32,
fp16_fp32_fp32_fp32,
),
all_byte_alignments,
(
mma_64x64x16,
),
cluster_shapes
),
# Kernels not necessarily in the unit tests, but used elsewhere
# and thus useful to have generated for profiling. They may
# duplicate kernels above. All of them are 2-D. In general,
# CUTLASS prefers 64 x 128 to 128 x 64 on sm90, even if the
# hardware permits 128 x 64.
(
# Fprop
#
# bf16bf16_bf16f32_f32
#
# cluster shape (2, 1, 1)
#
(ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
(ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
#
# f16f16_f16f16_f16
#
# cluster shape (1, 1, 1)
#
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 8), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 16), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 8), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 16), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 8), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 16), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 8), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 16), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 8), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 16), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 8), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 16), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 8), (1, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 16), (1, 1, 1)),
#
# f16f16_f16f32_f32
#
# cluster shape (2, 1, 1)
#
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 192, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 192, 16), (2, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 96, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 96, 16), (2, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
#
# f32f32_tf32f32_f32
#
# cluster shape (2, 1, 1)
#
(ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (128, 192, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
(ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (256, 96, 8), (2, 1, 1)),
#
# i8i8_i8i32_f32
#
# cluster shape (2, 1, 1)
#
(ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
(ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (128, 256, 32), (2, 1, 1)),
(ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
(ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (256, 128, 32), (2, 1, 1)),
#
# Dgrad
#
# bf16bf16_bf16f32_f32
#
# cluster shape (2, 1, 1)
#
(ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
(ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
(ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
(ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
#
# f16f16_f16f16_f16
#
# cluster shape (1, 1, 1)
#
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 8), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 16), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 8), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 16), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 8), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 16), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 8), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 16), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 8), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 16), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 8), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 16), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 8), (1, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 16), (1, 1, 1)),
#
# f16f16_f16f32_f32
#
# cluster shape (2, 1, 1)
#
(ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
(ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
),
)
# SM >= 90 kernels don't actually use warp_count, but the
# TileDescription class needs it. The 4 in the default
# warp_count has nothing to do with num_mma_per_tile.
warp_count = [4, 1, 1]
stages = 0 # zero means "deduce the number of stages automatically"
mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
schedule_pairs = (
(mainloop_schedule, epilogue_schedule),
)
tile_schedulers = (
TileSchedulerType.Default, # -> void
)
def make_math_instruction(data_types: Dict[str, DataType],
mma_shape: Tuple[int, int, int]) -> MathInstruction:
default_opcode = OpcodeClass.TensorOp
default_math_op = MathOperation.multiply_add
[A_data_type, B_data_type, C_data_type] = data_types
return MathInstruction(
instruction_shape,
A_data_type, B_data_type, C_data_type,
mma_shape,
data_types['a_type'], data_types['b_type'], data_types['c_type'],
default_opcode,
default_math_op
)
data_types_and_instruction_shapes = [
((DataType.f16, DataType.f16, DataType.f16), (64, 64, 16)),
((DataType.f16, DataType.f16, DataType.f32), (64, 64, 16)),
((DataType.bf16, DataType.bf16, DataType.f32), (64, 64, 16)),
]
math_instructions = map(lambda x: make_math_instruction(*x),
data_types_and_instruction_shapes)
cluster_shapes = [
[2, 1, 1],
[1, 1, 1],
]
conv_kinds = [
ConvKind.Fprop,
ConvKind.Dgrad
]
mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90
stages = 0 # zero means "deduce the number of stages automatically"
# tile_descriptions is a 2-level list.
# Each inner list is for each cluster shape.
for math_inst in math_instructions:
tile_descriptions = []
for cluster_shape in cluster_shapes:
tile_shape = [
math_inst.instruction_shape[0],
math_inst.instruction_shape[1],
math_inst.instruction_shape[2] * 4
]
warp_count = [4, 1, 1]
tile_description = TileDescription(
tile_shape, stages, warp_count, math_inst,
minimum_compute_capability, maximum_compute_capability,
cluster_shape)
tile_descriptions.append(tile_description)
# It's typical to get the data types from the math instruction.
data_type = {
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : math_inst.element_accumulator,
"d_type" : math_inst.element_accumulator,
"acc_type" : math_inst.element_accumulator,
"epi_type" : math_inst.element_accumulator
}
for conv_kind in conv_kinds:
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
schedule_pairs = [
(mainloop_schedule, epilogue_schedule)
]
CreateConvOperator3x(manifest,
dims_and_alignments = dims_and_alignments,
tile_descriptions = tile_descriptions,
data_types = data_type,
schedule_pairs = schedule_pairs,
tile_schedulers = [TileSchedulerType.Default], # -> void
conv_kind = conv_kind,
log_indent_level = log_indent_level)
for (conv_kind, spatial_dim, data_types, byte_alignments, mma_shape, cluster_shape) in combinations_of_parameters:
math_inst = make_math_instruction(data_types, mma_shape)
tile_shape = (mma_shape[0], mma_shape[1], num_mma_per_tile * mma_shape[2])
tile_description = TileDescription(tile_shape, stages, warp_count, math_inst,
minimum_compute_capability, maximum_compute_capability, cluster_shape)
assert(isinstance(spatial_dim, int))
assert(isinstance(byte_alignments, dict))
dims_and_alignments = (
(
(spatial_dim, byte_alignments['A']),
(spatial_dim, byte_alignments['B']),
(spatial_dim, byte_alignments['C']),
),
)
CreateConvOperator3x(manifest,
dims_and_alignments = dims_and_alignments,
tile_descriptions = [tile_description],
data_types = data_types,
schedule_pairs = schedule_pairs,
tile_schedulers = tile_schedulers,
conv_kind = conv_kind,
log_indent_level = log_indent_level)
def GenerateSM90(manifest, cuda_version):
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version)
@ -6738,8 +7008,8 @@ def GenerateSM90(manifest, cuda_version):
def numeric_log_level(log_level: str) -> int:
"""
Converts the string identifier of the log level into the numeric identifier used
in setting the log level
Converts the string identifier of the log level
into the numeric identifier used in setting the log level.
:param x: string representation of log level (e.g., 'INFO', 'DEBUG')
:type x: str
@ -6762,8 +7032,18 @@ def define_parser():
parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory")
parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.")
parser.add_argument("--architectures", default='53;60;61;70;75;80;90', help="Target compute architectures")
parser.add_argument("--kernels", default='', help='Comma delimited list to filter kernels by name.')
parser.add_argument("--ignore-kernels", default='', help='Comma delimited list of kernels to exclude from build.')
parser.add_argument("--kernels", default='', help='Comma-delimited list to filter kernels by name. ' +
'Specifying this as \"all\" includes ALL the kernels, ' +
'while not specifying this includes only the default set of kernels.')
parser.add_argument("--ignore-kernels", default='', help='Comma-delimited list of kernels ' +
'to exclude from build. For backwards compatibility reasons, ' +
'this option only takes effect if --kernels is set to a nonempty value.')
parser.add_argument("--exclude-kernels", default='', help='Comma-delimited list of kernels ' +
'to exclude from build. In contrast to --ignore-kernels, ' +
'this option always takes effect, ' +
'whether or not --kernels is set to a nonempty value. ' +
'It also can exclude kernels from the filter file ' +
'(see --kernel-filter-file option below).')
parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.')
parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit")
parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file')

View File

@ -506,6 +506,7 @@ class Manifest:
self.operations_enabled = []
self.selected_kernels = []
self.ignore_kernel_names = []
self.exclude_kernel_names = []
self.compute_capabilities = [50,]
self.curr_build_dir = '.'
self.filter_by_cc = True
@ -546,6 +547,7 @@ class Manifest:
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
if args.kernel_filter_file is None:
self.kernel_filter_list = []
@ -612,41 +614,54 @@ class Manifest:
if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled:
return False
name = operation.procedural_name()
# eliminate duplicates
if operation.procedural_name() in self.operations_by_name.keys():
if name in self.operations_by_name.keys():
return False
# Filter based on list of valid substrings
if len(self.kernel_names):
name = operation.procedural_name()
enabled = False
# compare against the include list
for name_substr in self.kernel_names:
if self._filter_string_matches(name_substr, name):
_LOGGER.debug("Kernel {kernel} included due to filter string '{filt}'.".format(
kernel = operation.procedural_name(),
filt = name_substr))
_LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.")
enabled = True
break
else:
_LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.")
# compare against the exclude list
for name_substr in self.ignore_kernel_names:
if self._filter_string_matches(name_substr, name):
_LOGGER.debug("Kernel {kernel} ignored due to filter string '{filt}'.".format(
kernel = operation.procedural_name(),
filt = name_substr))
_LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.")
enabled = False
break
else:
_LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.")
if len(self.kernel_filter_list) > 0:
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
_LOGGER.debug("Kernel {kernel} matched via kernel filter file.".format(kernel = operation.procedural_name()))
enabled = True
else:
_LOGGER.debug("Kernel {kernel} culled due to no match in kernel filter file.".format(kernel = operation.procedural_name()))
enabled = False
if self.filter_out_kernels(name, self.kernel_filter_list):
_LOGGER.debug(f"Kernel {name} matched via kernel filter file.")
enabled = True
else:
_LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.")
enabled = False
# CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect
# if CUTLASS_LIBRARY_KERNELS was specified.
# Changing that would break backwards compatibility.
# Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS,
# that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified.
for name_substr in self.exclude_kernel_names:
if self._filter_string_matches(name_substr, name):
_LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.")
enabled = False
break
else:
_LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.")
# TODO: filter based on compute data type
return enabled

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='3.5.0',
version='3.5.1',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='3.5.0',
version='3.5.1',
description='Python implementation of CuTe',
packages=['pycute'],
)