@ -121,7 +121,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '3.5.0'
|
||||
this.__version__ = '3.5.1'
|
||||
|
||||
from cutlass.backend import create_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
|
||||
@ -154,20 +154,6 @@ using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
|
||||
|
||||
|
||||
class Sm90RowBroadcastImpl(RowBroadcastImpl):
|
||||
|
||||
@property
|
||||
def descriptor(self) -> str:
|
||||
"""
|
||||
Descriptor for Aux Load
|
||||
"""
|
||||
return f"{self.name_camel}Descriptor"
|
||||
|
||||
def decl_descriptor(self) -> str:
|
||||
"""
|
||||
Declare the descriptor type
|
||||
"""
|
||||
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::RowBroadcastDescriptor<EpilogueDescriptor, {DataTypeTag[self.element]}>;\n"
|
||||
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
@ -176,22 +162,14 @@ class Sm90RowBroadcastImpl(RowBroadcastImpl):
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = self.decl_descriptor()
|
||||
self._type_decl += f"""
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
{self.descriptor}::Stages, typename EpilogueDescriptor::TileShape,
|
||||
typename {self.descriptor}::Element, {self.stride_mnl}
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
|
||||
"""
|
||||
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
||||
"""
|
||||
stages = (stages_c + epi_tiles - 1) // epi_tiles + 1
|
||||
return (DataTypeSize[self.element] * cta_tile_mnk[1] * stages // 8, 16)
|
||||
|
||||
|
||||
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
|
||||
|
||||
|
||||
@ -150,7 +150,7 @@ class Conv2dOperation:
|
||||
else:
|
||||
group_conv_name = ""
|
||||
|
||||
if self.stride_support == StrideSupport.Unity:
|
||||
if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad:
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
|
||||
else:
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
|
||||
|
||||
@ -69,8 +69,8 @@ using ${operation_name}_epilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
${arch},
|
||||
${opcode_class_epi},
|
||||
${tile_shape}, // tile shape
|
||||
${cluster_shape}, // cluster shape
|
||||
${output_cta_tile_shape}, // output cta tile shape
|
||||
${cluster_shape}, // cluster shape
|
||||
${epi_tile_mn},
|
||||
${element_accumulator},
|
||||
${element_compute},
|
||||
@ -88,8 +88,8 @@ using ${operation_name}_mainloop =
|
||||
${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>,
|
||||
${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>,
|
||||
${element_accumulator},
|
||||
${tile_shape}, // tile shape
|
||||
${cluster_shape}, // cluster shape
|
||||
${mma_tile_shape}, // mma tile shape
|
||||
${cluster_shape}, // cluster shape
|
||||
${stages},
|
||||
${kernel_schedule}
|
||||
>::CollectiveOp;
|
||||
@ -106,30 +106,54 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
||||
def arch_number_to_type(self, arch: int) -> str:
|
||||
return f"cutlass::arch::Sm{arch}"
|
||||
|
||||
def tile_shape(self, operation) -> str:
|
||||
def output_cta_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
|
||||
# For all three kinds of convolutions, the tile shape's K mode
|
||||
# differs from GEMM in that needs to be wrapped in a Shape.
|
||||
# For Wgrad convolutions specifically,
|
||||
# the N tile shape also needs to be wrapped in a Shape.
|
||||
m_template = 'cute::_${tile_shape_m}'
|
||||
m_template = 'cute::_${cta_m}'
|
||||
if operation.conv_kind == ConvKind.Wgrad:
|
||||
n_template = 'cute::Shape<cute::_${tile_shape_n}>'
|
||||
n_template = 'cute::Shape<cute::_${cta_n}>'
|
||||
else:
|
||||
n_template = 'cute::_${tile_shape_n}'
|
||||
k_template = 'cute::Shape<cute::_${tile_shape_k}>'
|
||||
n_template = 'cute::_${cta_n}'
|
||||
k_template = 'cute::Shape<cute::_${cta_k}>'
|
||||
|
||||
tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
||||
output_cta_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
||||
values = {
|
||||
'tile_shape_m': operation.tile_description.tile_shape[0],
|
||||
'tile_shape_n': operation.tile_description.tile_shape[1],
|
||||
'tile_shape_k': operation.tile_description.tile_shape[2]
|
||||
'cta_m': cta_m,
|
||||
'cta_n': cta_n,
|
||||
'cta_k': cta_k
|
||||
}
|
||||
return Template(tile_shape_template).substitute(values)
|
||||
return Template(output_cta_tile_shape_template).substitute(values)
|
||||
|
||||
def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
|
||||
mma_m = cta_m
|
||||
mma_n = cta_n
|
||||
mma_k = cta_k
|
||||
|
||||
# For all three kinds of convolutions, the tile shape's K mode
|
||||
# differs from GEMM in that needs to be wrapped in a Shape.
|
||||
# For Wgrad convolutions specifically,
|
||||
# the N tile shape also needs to be wrapped in a Shape.
|
||||
m_template = 'cute::_${mma_m}'
|
||||
if operation.conv_kind == ConvKind.Wgrad:
|
||||
n_template = 'cute::Shape<cute::_${mma_n}>'
|
||||
else:
|
||||
n_template = 'cute::_${mma_n}'
|
||||
k_template = 'cute::Shape<cute::_${mma_k}>'
|
||||
|
||||
mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
||||
values = {
|
||||
'mma_m': mma_m,
|
||||
'mma_n': mma_n,
|
||||
'mma_k': mma_k
|
||||
}
|
||||
return Template(mma_tile_shape_template).substitute(values)
|
||||
|
||||
def cluster_shape(self, operation) -> str:
|
||||
m_template = 'cute::_${cluster_shape_m}'
|
||||
n_template = 'cute::_${cluster_shape_n}'
|
||||
k_template = 'cute::_${cluster_shape_k}'
|
||||
m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)'
|
||||
n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)'
|
||||
k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)'
|
||||
cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
||||
values = {
|
||||
'cluster_shape_m': operation.tile_description.cluster_shape[0],
|
||||
@ -159,6 +183,10 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
||||
opcode_class_epi = opcode_class_main
|
||||
|
||||
tile_shape = operation.tile_description.tile_shape
|
||||
cluster_m = operation.tile_description.cluster_shape[0]
|
||||
cluster_n = operation.tile_description.cluster_shape[1]
|
||||
|
||||
cta_m, cta_n, cta_k = tile_shape
|
||||
warp_count = operation.tile_description.warp_count
|
||||
epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
|
||||
@ -189,19 +217,20 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
||||
'element_d': DataTypeTag[operation.D.element],
|
||||
'layout_d': LayoutTag[operation.D.layout],
|
||||
'align_d': int(operation.D.alignment),
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class': opcode_class,
|
||||
'arch': self.arch_number_to_type(operation.arch),
|
||||
'tile_shape': self.tile_shape(operation),
|
||||
'cluster_shape': self.cluster_shape(operation),
|
||||
'opcode_class_epi': opcode_class_epi,
|
||||
'opcode_class_main': opcode_class_main,
|
||||
'epi_tile_mn': epi_tile_mn,
|
||||
'stages': self.stage_count(operation),
|
||||
'kernel_schedule': kernel_schedule,
|
||||
'epilogue_schedule': epilogue_schedule,
|
||||
'tile_scheduler': tile_scheduler,
|
||||
'element_compute': DataTypeTag[operation.element_compute]
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class': opcode_class,
|
||||
'arch': self.arch_number_to_type(operation.arch),
|
||||
'output_cta_tile_shape': self.output_cta_tile_shape(operation, cta_m, cta_n, cta_k),
|
||||
'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
|
||||
'cluster_shape': self.cluster_shape(operation),
|
||||
'opcode_class_epi': opcode_class_epi,
|
||||
'opcode_class_main': opcode_class_main,
|
||||
'epi_tile_mn': epi_tile_mn,
|
||||
'stages': self.stage_count(operation),
|
||||
'kernel_schedule': kernel_schedule,
|
||||
'epilogue_schedule': epilogue_schedule,
|
||||
'tile_scheduler': tile_scheduler,
|
||||
'element_compute': DataTypeTag[operation.element_compute]
|
||||
}
|
||||
return Template(self.template).substitute(values)
|
||||
|
||||
|
||||
@ -178,16 +178,28 @@ class GemmOperation:
|
||||
if self.is_complex():
|
||||
extended_name = "${core_name}"
|
||||
else:
|
||||
# e.g. f16_f16_f32_void_f32 kernel
|
||||
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${element_c}_${core_name}_${element_a}"
|
||||
if self.is_mixed_input():
|
||||
extended_name += "_${element_b}"
|
||||
|
||||
# e.g. f32_f32_f32_void_f32 kernel
|
||||
elif self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element == self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${element_c}_${core_name}"
|
||||
if self.is_mixed_input():
|
||||
extended_name += "_${element_b}"
|
||||
|
||||
# e.g. f16_f16_f32_f32_f32 kernel
|
||||
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${core_name}_${element_a}"
|
||||
if self.is_mixed_input():
|
||||
extended_name += "_${element_b}"
|
||||
|
||||
# e.g. f32_f32_f32_f32_f32 kernel
|
||||
else:
|
||||
extended_name = "${core_name}"
|
||||
|
||||
|
||||
@ -36,13 +36,13 @@ Utilities for enumerating CUTLASS library kernels
|
||||
|
||||
import argparse
|
||||
import enum
|
||||
from itertools import product
|
||||
from itertools import chain, product
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
import sys
|
||||
import copy
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -513,7 +513,7 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
|
||||
new_operations = [
|
||||
# None grouped kernel
|
||||
Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_),
|
||||
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_),
|
||||
]
|
||||
|
||||
# Instance group conv kernel
|
||||
@ -521,12 +521,12 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
|
||||
tile.minimum_compute_capability >= 80:
|
||||
# SingleGroup kernel
|
||||
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
|
||||
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
|
||||
|
||||
# Analytic iterator supports MultipleGroup mode
|
||||
if iterator_algorithm == IteratorAlgorithm.Analytic:
|
||||
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup))
|
||||
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup))
|
||||
|
||||
for new_operation in new_operations:
|
||||
manifest.append(new_operation)
|
||||
@ -884,7 +884,7 @@ class ConvOperation3x:
|
||||
prefix = ''
|
||||
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
||||
prefix = 'g'
|
||||
return prefix + ShortDataTypeNames[self.accumulator_type()]
|
||||
return prefix + DataTypeNames[self.accumulator_type()]
|
||||
|
||||
def is_tensor_op(self):
|
||||
tensor_ops = [
|
||||
@ -1054,8 +1054,11 @@ def CreateConvOperator3x(manifest: Manifest,
|
||||
log_debug_line(f'conv_kind: {conv_kind}', log_indent_level)
|
||||
|
||||
for triple in dims_and_alignments:
|
||||
spatial_dimensionality = None # to be determined by loop below
|
||||
assert(isinstance(triple, tuple) or isinstance(triple, list))
|
||||
assert(len(triple) == 3)
|
||||
|
||||
spatial_dimensionality = None # to be determined by loop below
|
||||
|
||||
for entry in triple: # [A, B, C]
|
||||
assert(len(entry) == 2)
|
||||
[dim, alignment] = entry
|
||||
@ -6631,85 +6634,352 @@ def GenerateSM90_Conv3x(manifest, cuda_version,
|
||||
minimum_compute_capability = 90
|
||||
maximum_compute_capability = 90
|
||||
|
||||
spatial_dims = [2, 3]
|
||||
spatial_dims = (2, 3)
|
||||
|
||||
def make_dims_and_alignments_triple(dim: int):
|
||||
byte_alignment_required_by_tma = 16
|
||||
return ((dim, byte_alignment_required_by_tma), # A
|
||||
(dim, byte_alignment_required_by_tma), # B
|
||||
(dim, byte_alignment_required_by_tma)) # C
|
||||
dims_and_alignments = [make_dims_and_alignments_triple(dim) for dim in spatial_dims]
|
||||
# This function only generates kernels that use TMA.
|
||||
byte_alignment_required_by_tma = 16
|
||||
tma_byte_alignments = {
|
||||
'A': byte_alignment_required_by_tma,
|
||||
'B': byte_alignment_required_by_tma,
|
||||
'C': byte_alignment_required_by_tma,
|
||||
}
|
||||
|
||||
def make_math_instruction(data_types: Tuple[DataType, DataType, DataType],
|
||||
instruction_shape: Tuple[int, int, int]) -> MathInstruction:
|
||||
# For tuples of one element, the element needs to end with comma.
|
||||
all_byte_alignments = (
|
||||
tma_byte_alignments,
|
||||
)
|
||||
|
||||
# MMA shapes (MMA_M, MMA_N, MMA_K):
|
||||
#
|
||||
# Different hardware MMA instructions may have different MMA shapes.
|
||||
# This function may generate kernels with different MMA shapes for
|
||||
# different data types, either because the hardware only supports
|
||||
# certain shapes for certain types, or for performance reasons
|
||||
# (CUTLASS doesn't need to generate all valid kernels for the
|
||||
# profiler library, just the best-performing ones).
|
||||
#
|
||||
# The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K)
|
||||
# instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K,
|
||||
# where 4, the "number of MMA instructions per tile," is determined
|
||||
# through some combination of modeling and experiment.
|
||||
#
|
||||
# For performance on sm90, generally CUTLASS generates 64x128
|
||||
# instead of 128x64.
|
||||
mma_64x64x16 = ( 64, 64, 16)
|
||||
mma_64x64x8 = ( 64, 64, 8)
|
||||
|
||||
num_mma_per_tile = 4
|
||||
|
||||
# Cluster shapes (1, 1, 1) and (2, 2, 1) are valid,
|
||||
# but not included, because they tend not to perform as well.
|
||||
cluster_shapes = (
|
||||
(2, 1, 1),
|
||||
(1, 2, 1),
|
||||
)
|
||||
|
||||
fp16 = DataType.f16
|
||||
bf16 = DataType.bf16
|
||||
fp32 = DataType.f32
|
||||
s8 = DataType.s8
|
||||
s32 = DataType.s32
|
||||
|
||||
# When generating kernels, the usual way is to specify 4 types,
|
||||
# (A, B, Acc, C/D). Tests instead have 5 types,
|
||||
# (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute),
|
||||
# where ElementCompute is also called 'epi_type',
|
||||
# and corresponds to the type of epilogue activations.
|
||||
# This script maps tests' 5 types to 4 types
|
||||
# by making ElementCompute the same as ElementOut.
|
||||
|
||||
fp16_fp32_fp16_fp32 = {
|
||||
'a_type': fp16, # ElementAct(ivation)
|
||||
'b_type': fp16, # ElementF(i)lt(er)
|
||||
'c_type': fp32, # ElementAcc
|
||||
'd_type': fp32, # ElementOut (used only by CollectiveEpilogue)
|
||||
'acc_type': fp16, # ElementAcc
|
||||
'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue)
|
||||
}
|
||||
fp16_fp32_fp32_fp32 = {
|
||||
'a_type': fp16,
|
||||
'b_type': fp16,
|
||||
'c_type': fp32,
|
||||
'd_type': fp32,
|
||||
'acc_type': fp32,
|
||||
'epi_type': fp32,
|
||||
}
|
||||
fp32_fp32_fp32_fp32 = {
|
||||
'a_type': fp32,
|
||||
'b_type': fp32,
|
||||
'c_type': fp32,
|
||||
'd_type': fp32,
|
||||
'acc_type': fp32,
|
||||
'epi_type': fp32,
|
||||
}
|
||||
s8_s32_s32_s32 = {
|
||||
'a_type': s8,
|
||||
'b_type': s8,
|
||||
'c_type': s32,
|
||||
'd_type': s32,
|
||||
'acc_type': s32,
|
||||
'epi_type': s32,
|
||||
}
|
||||
|
||||
# Other NVIDIA libraries may have the habit of specifying data types like this.
|
||||
bf16bf16_bf16f32_f32 = {
|
||||
'a_type': bf16,
|
||||
'b_type': bf16,
|
||||
'c_type': fp32,
|
||||
'd_type': fp32,
|
||||
'acc_type': fp32,
|
||||
'epi_type': fp32,
|
||||
}
|
||||
f16f16_f16f16_f16 = {
|
||||
'a_type': fp16,
|
||||
'b_type': fp16,
|
||||
'c_type': fp16,
|
||||
'd_type': fp16,
|
||||
'acc_type': fp16,
|
||||
'epi_type': fp16,
|
||||
}
|
||||
f16f16_f16f32_f32 = {
|
||||
'a_type': fp16,
|
||||
'b_type': fp16,
|
||||
'c_type': fp16,
|
||||
'd_type': fp16,
|
||||
'acc_type': fp32,
|
||||
'epi_type': fp32,
|
||||
}
|
||||
f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32
|
||||
|
||||
i8i8_i8i32_f32 = {
|
||||
'a_type': s8,
|
||||
'b_type': s8,
|
||||
'c_type': s32,
|
||||
'd_type': s32,
|
||||
'acc_type': s32,
|
||||
'epi_type': s32,
|
||||
}
|
||||
|
||||
# Each element in the outermost iterable is one combination of
|
||||
#
|
||||
# (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes)
|
||||
#
|
||||
# for which to generate a kernel. spatial_dimension is the spatial
|
||||
# dimension of the convolution: either 1, 2, or 3. byte_alignments
|
||||
# is a triple of required minimum byte alignments for A, B, and C.
|
||||
#
|
||||
# Note that itertools functions produce a single-pass generator.
|
||||
# The code doesn't need a multipass iterable, but if one did, one
|
||||
# could call `tuple` or `list` on the generator.
|
||||
#
|
||||
# While this happens to use the same cluster sizes for each element,
|
||||
# the code doesn't require that. Different convolution kinds, data
|
||||
# types, or mma sizes might have different optimal cluster sizes.
|
||||
combinations_of_parameters = chain(
|
||||
# The following are all the kernels exercised in the unit tests.
|
||||
# Please try to keep in sync with the unit tests.
|
||||
product(
|
||||
(
|
||||
ConvKind.Fprop,
|
||||
),
|
||||
spatial_dims,
|
||||
(
|
||||
fp16_fp32_fp16_fp32,
|
||||
fp16_fp32_fp32_fp32,
|
||||
s8_s32_s32_s32,
|
||||
),
|
||||
all_byte_alignments,
|
||||
(
|
||||
mma_64x64x16,
|
||||
),
|
||||
cluster_shapes
|
||||
),
|
||||
product(
|
||||
(
|
||||
ConvKind.Fprop,
|
||||
),
|
||||
spatial_dims,
|
||||
(
|
||||
fp32_fp32_fp32_fp32,
|
||||
),
|
||||
all_byte_alignments,
|
||||
(
|
||||
mma_64x64x8,
|
||||
),
|
||||
cluster_shapes
|
||||
),
|
||||
product(
|
||||
(
|
||||
ConvKind.Dgrad,
|
||||
),
|
||||
spatial_dims,
|
||||
(
|
||||
fp16_fp32_fp16_fp32,
|
||||
fp16_fp32_fp32_fp32,
|
||||
),
|
||||
all_byte_alignments,
|
||||
(
|
||||
mma_64x64x16,
|
||||
),
|
||||
cluster_shapes
|
||||
),
|
||||
# Kernels not necessarily in the unit tests, but used elsewhere
|
||||
# and thus useful to have generated for profiling. They may
|
||||
# duplicate kernels above. All of them are 2-D. In general,
|
||||
# CUTLASS prefers 64 x 128 to 128 x 64 on sm90, even if the
|
||||
# hardware permits 128 x 64.
|
||||
(
|
||||
# Fprop
|
||||
#
|
||||
# bf16bf16_bf16f32_f32
|
||||
#
|
||||
# cluster shape (2, 1, 1)
|
||||
#
|
||||
(ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
|
||||
#
|
||||
# f16f16_f16f16_f16
|
||||
#
|
||||
# cluster shape (1, 1, 1)
|
||||
#
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 8), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 16), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 8), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 16), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 8), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 16), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 8), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 16), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 8), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 16), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 8), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 16), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 8), (1, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 16), (1, 1, 1)),
|
||||
#
|
||||
# f16f16_f16f32_f32
|
||||
#
|
||||
# cluster shape (2, 1, 1)
|
||||
#
|
||||
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 192, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 192, 16), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 96, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 96, 16), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
|
||||
#
|
||||
# f32f32_tf32f32_f32
|
||||
#
|
||||
# cluster shape (2, 1, 1)
|
||||
#
|
||||
(ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (128, 192, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (256, 96, 8), (2, 1, 1)),
|
||||
#
|
||||
# i8i8_i8i32_f32
|
||||
#
|
||||
# cluster shape (2, 1, 1)
|
||||
#
|
||||
(ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (128, 256, 32), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
|
||||
(ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (256, 128, 32), (2, 1, 1)),
|
||||
#
|
||||
# Dgrad
|
||||
#
|
||||
# bf16bf16_bf16f32_f32
|
||||
#
|
||||
# cluster shape (2, 1, 1)
|
||||
#
|
||||
(ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
|
||||
#
|
||||
# f16f16_f16f16_f16
|
||||
#
|
||||
# cluster shape (1, 1, 1)
|
||||
#
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 8), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 16), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 8), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 16), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 8), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 16), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 8), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 16), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 8), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 16), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 8), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 16), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 8), (1, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 16), (1, 1, 1)),
|
||||
#
|
||||
# f16f16_f16f32_f32
|
||||
#
|
||||
# cluster shape (2, 1, 1)
|
||||
#
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)),
|
||||
(ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)),
|
||||
),
|
||||
)
|
||||
|
||||
# SM >= 90 kernels don't actually use warp_count, but the
|
||||
# TileDescription class needs it. The 4 in the default
|
||||
# warp_count has nothing to do with num_mma_per_tile.
|
||||
warp_count = [4, 1, 1]
|
||||
|
||||
stages = 0 # zero means "deduce the number of stages automatically"
|
||||
|
||||
mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90
|
||||
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
|
||||
schedule_pairs = (
|
||||
(mainloop_schedule, epilogue_schedule),
|
||||
)
|
||||
tile_schedulers = (
|
||||
TileSchedulerType.Default, # -> void
|
||||
)
|
||||
|
||||
def make_math_instruction(data_types: Dict[str, DataType],
|
||||
mma_shape: Tuple[int, int, int]) -> MathInstruction:
|
||||
default_opcode = OpcodeClass.TensorOp
|
||||
default_math_op = MathOperation.multiply_add
|
||||
[A_data_type, B_data_type, C_data_type] = data_types
|
||||
return MathInstruction(
|
||||
instruction_shape,
|
||||
A_data_type, B_data_type, C_data_type,
|
||||
mma_shape,
|
||||
data_types['a_type'], data_types['b_type'], data_types['c_type'],
|
||||
default_opcode,
|
||||
default_math_op
|
||||
)
|
||||
data_types_and_instruction_shapes = [
|
||||
((DataType.f16, DataType.f16, DataType.f16), (64, 64, 16)),
|
||||
((DataType.f16, DataType.f16, DataType.f32), (64, 64, 16)),
|
||||
((DataType.bf16, DataType.bf16, DataType.f32), (64, 64, 16)),
|
||||
]
|
||||
math_instructions = map(lambda x: make_math_instruction(*x),
|
||||
data_types_and_instruction_shapes)
|
||||
|
||||
cluster_shapes = [
|
||||
[2, 1, 1],
|
||||
[1, 1, 1],
|
||||
]
|
||||
conv_kinds = [
|
||||
ConvKind.Fprop,
|
||||
ConvKind.Dgrad
|
||||
]
|
||||
mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90
|
||||
stages = 0 # zero means "deduce the number of stages automatically"
|
||||
|
||||
# tile_descriptions is a 2-level list.
|
||||
# Each inner list is for each cluster shape.
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes:
|
||||
tile_shape = [
|
||||
math_inst.instruction_shape[0],
|
||||
math_inst.instruction_shape[1],
|
||||
math_inst.instruction_shape[2] * 4
|
||||
]
|
||||
warp_count = [4, 1, 1]
|
||||
tile_description = TileDescription(
|
||||
tile_shape, stages, warp_count, math_inst,
|
||||
minimum_compute_capability, maximum_compute_capability,
|
||||
cluster_shape)
|
||||
tile_descriptions.append(tile_description)
|
||||
|
||||
# It's typical to get the data types from the math instruction.
|
||||
data_type = {
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : math_inst.element_accumulator,
|
||||
"d_type" : math_inst.element_accumulator,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : math_inst.element_accumulator
|
||||
}
|
||||
|
||||
for conv_kind in conv_kinds:
|
||||
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
|
||||
schedule_pairs = [
|
||||
(mainloop_schedule, epilogue_schedule)
|
||||
]
|
||||
CreateConvOperator3x(manifest,
|
||||
dims_and_alignments = dims_and_alignments,
|
||||
tile_descriptions = tile_descriptions,
|
||||
data_types = data_type,
|
||||
schedule_pairs = schedule_pairs,
|
||||
tile_schedulers = [TileSchedulerType.Default], # -> void
|
||||
conv_kind = conv_kind,
|
||||
log_indent_level = log_indent_level)
|
||||
for (conv_kind, spatial_dim, data_types, byte_alignments, mma_shape, cluster_shape) in combinations_of_parameters:
|
||||
math_inst = make_math_instruction(data_types, mma_shape)
|
||||
tile_shape = (mma_shape[0], mma_shape[1], num_mma_per_tile * mma_shape[2])
|
||||
tile_description = TileDescription(tile_shape, stages, warp_count, math_inst,
|
||||
minimum_compute_capability, maximum_compute_capability, cluster_shape)
|
||||
assert(isinstance(spatial_dim, int))
|
||||
assert(isinstance(byte_alignments, dict))
|
||||
dims_and_alignments = (
|
||||
(
|
||||
(spatial_dim, byte_alignments['A']),
|
||||
(spatial_dim, byte_alignments['B']),
|
||||
(spatial_dim, byte_alignments['C']),
|
||||
),
|
||||
)
|
||||
CreateConvOperator3x(manifest,
|
||||
dims_and_alignments = dims_and_alignments,
|
||||
tile_descriptions = [tile_description],
|
||||
data_types = data_types,
|
||||
schedule_pairs = schedule_pairs,
|
||||
tile_schedulers = tile_schedulers,
|
||||
conv_kind = conv_kind,
|
||||
log_indent_level = log_indent_level)
|
||||
|
||||
def GenerateSM90(manifest, cuda_version):
|
||||
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version)
|
||||
@ -6738,8 +7008,8 @@ def GenerateSM90(manifest, cuda_version):
|
||||
|
||||
def numeric_log_level(log_level: str) -> int:
|
||||
"""
|
||||
Converts the string identifier of the log level into the numeric identifier used
|
||||
in setting the log level
|
||||
Converts the string identifier of the log level
|
||||
into the numeric identifier used in setting the log level.
|
||||
|
||||
:param x: string representation of log level (e.g., 'INFO', 'DEBUG')
|
||||
:type x: str
|
||||
@ -6762,8 +7032,18 @@ def define_parser():
|
||||
parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory")
|
||||
parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.")
|
||||
parser.add_argument("--architectures", default='53;60;61;70;75;80;90', help="Target compute architectures")
|
||||
parser.add_argument("--kernels", default='', help='Comma delimited list to filter kernels by name.')
|
||||
parser.add_argument("--ignore-kernels", default='', help='Comma delimited list of kernels to exclude from build.')
|
||||
parser.add_argument("--kernels", default='', help='Comma-delimited list to filter kernels by name. ' +
|
||||
'Specifying this as \"all\" includes ALL the kernels, ' +
|
||||
'while not specifying this includes only the default set of kernels.')
|
||||
parser.add_argument("--ignore-kernels", default='', help='Comma-delimited list of kernels ' +
|
||||
'to exclude from build. For backwards compatibility reasons, ' +
|
||||
'this option only takes effect if --kernels is set to a nonempty value.')
|
||||
parser.add_argument("--exclude-kernels", default='', help='Comma-delimited list of kernels ' +
|
||||
'to exclude from build. In contrast to --ignore-kernels, ' +
|
||||
'this option always takes effect, ' +
|
||||
'whether or not --kernels is set to a nonempty value. ' +
|
||||
'It also can exclude kernels from the filter file ' +
|
||||
'(see --kernel-filter-file option below).')
|
||||
parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.')
|
||||
parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit")
|
||||
parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file')
|
||||
|
||||
@ -506,6 +506,7 @@ class Manifest:
|
||||
self.operations_enabled = []
|
||||
self.selected_kernels = []
|
||||
self.ignore_kernel_names = []
|
||||
self.exclude_kernel_names = []
|
||||
self.compute_capabilities = [50,]
|
||||
self.curr_build_dir = '.'
|
||||
self.filter_by_cc = True
|
||||
@ -546,6 +547,7 @@ class Manifest:
|
||||
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
|
||||
|
||||
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
|
||||
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
|
||||
|
||||
if args.kernel_filter_file is None:
|
||||
self.kernel_filter_list = []
|
||||
@ -612,41 +614,54 @@ class Manifest:
|
||||
if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled:
|
||||
return False
|
||||
|
||||
name = operation.procedural_name()
|
||||
|
||||
# eliminate duplicates
|
||||
if operation.procedural_name() in self.operations_by_name.keys():
|
||||
if name in self.operations_by_name.keys():
|
||||
return False
|
||||
|
||||
# Filter based on list of valid substrings
|
||||
if len(self.kernel_names):
|
||||
name = operation.procedural_name()
|
||||
enabled = False
|
||||
|
||||
# compare against the include list
|
||||
for name_substr in self.kernel_names:
|
||||
if self._filter_string_matches(name_substr, name):
|
||||
_LOGGER.debug("Kernel {kernel} included due to filter string '{filt}'.".format(
|
||||
kernel = operation.procedural_name(),
|
||||
filt = name_substr))
|
||||
_LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.")
|
||||
enabled = True
|
||||
break
|
||||
else:
|
||||
_LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.")
|
||||
|
||||
# compare against the exclude list
|
||||
for name_substr in self.ignore_kernel_names:
|
||||
if self._filter_string_matches(name_substr, name):
|
||||
_LOGGER.debug("Kernel {kernel} ignored due to filter string '{filt}'.".format(
|
||||
kernel = operation.procedural_name(),
|
||||
filt = name_substr))
|
||||
_LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.")
|
||||
enabled = False
|
||||
break
|
||||
else:
|
||||
_LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.")
|
||||
|
||||
if len(self.kernel_filter_list) > 0:
|
||||
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
|
||||
_LOGGER.debug("Kernel {kernel} matched via kernel filter file.".format(kernel = operation.procedural_name()))
|
||||
enabled = True
|
||||
else:
|
||||
_LOGGER.debug("Kernel {kernel} culled due to no match in kernel filter file.".format(kernel = operation.procedural_name()))
|
||||
enabled = False
|
||||
if self.filter_out_kernels(name, self.kernel_filter_list):
|
||||
_LOGGER.debug(f"Kernel {name} matched via kernel filter file.")
|
||||
enabled = True
|
||||
else:
|
||||
_LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.")
|
||||
enabled = False
|
||||
|
||||
# CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect
|
||||
# if CUTLASS_LIBRARY_KERNELS was specified.
|
||||
# Changing that would break backwards compatibility.
|
||||
# Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS,
|
||||
# that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified.
|
||||
for name_substr in self.exclude_kernel_names:
|
||||
if self._filter_string_matches(name_substr, name):
|
||||
_LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.")
|
||||
enabled = False
|
||||
break
|
||||
else:
|
||||
_LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.")
|
||||
|
||||
# TODO: filter based on compute data type
|
||||
return enabled
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='3.5.0',
|
||||
version='3.5.1',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='3.5.0',
|
||||
version='3.5.1',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user