release 2.11 (#703)

This commit is contained in:
Aditya Atluri
2022-11-19 06:02:15 -08:00
committed by GitHub
parent 3c90f6aea6
commit c975e2ccbb
329 changed files with 47332 additions and 10607 deletions

View File

@ -17,7 +17,8 @@ from library import *
class Conv2dOperation:
#
def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1):
stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \
group_mode = GroupMode.NoneGroup):
self.operation_kind = OperationKind.Conv2d
self.arch = arch
@ -31,6 +32,7 @@ class Conv2dOperation:
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
self.group_mode = group_mode
#
def is_complex(self):
complex_operators = [
@ -95,17 +97,18 @@ class Conv2dOperation:
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
threadblock = "%dx%d_%dx%d" % (
self.tile_description.threadblock_shape[0],
self.tile_description.threadblock_shape[1],
self.tile_description.threadblock_shape[2],
self.tile_description.stages
)
threadblock = self.tile_description.procedural_name()
# grouped conv
if self.group_mode != GroupMode.NoneGroup:
group_conv_name = f"{GroupModeNames[self.group_mode]}_"
else:
group_conv_name = ""
if self.stride_support == StrideSupport.Unity:
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}"
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
else:
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
return SubstituteTemplate(
configuration_name,
@ -115,6 +118,7 @@ class Conv2dOperation:
'threadblock': threadblock,
'layout': self.layout_name(),
'alignment': "%d" % self.A.alignment,
'group_conv_name': group_conv_name
}
)
@ -162,7 +166,77 @@ class EmitConv2dInstance:
${align_b}
>::Kernel;
"""
self.template_group_conv = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
${group_mode},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
"""
self.template_depthwise_direct_conv = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue},
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>,
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
1,
${threadblock_output_shape_n},
${threadblock_output_shape_p},
${threadblock_output_shape_q}>,
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
cutlass::MatrixShape<${stride_r}, ${stride_s}>,
cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
>::Kernel;
"""
def emit(self, operation):
@ -206,7 +280,32 @@ class EmitConv2dInstance:
'align_b': str(operation.B.alignment),
}
return SubstituteTemplate(self.template, values)
if operation.group_mode == GroupMode.NoneGroup:
return SubstituteTemplate(self.template, values)
elif operation.group_mode == GroupMode.Depthwise:
values['group_mode'] = GroupModeTag[operation.group_mode]
# Setup other template params
values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0])
values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1])
values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2])
values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3])
values['filter_shape_r'] = str(operation.tile_description.filter_shape[0])
values['filter_shape_s'] = str(operation.tile_description.filter_shape[1])
values['stride_r'] = str(operation.tile_description.stride[0])
values['stride_s'] = str(operation.tile_description.stride[1])
values['dilation_r'] = str(operation.tile_description.dilation[0])
values['dilation_s'] = str(operation.tile_description.dilation[1])
return SubstituteTemplate(self.template_depthwise_direct_conv, values)
else:
values['group_mode'] = GroupModeTag[operation.group_mode]
return SubstituteTemplate(self.template_group_conv, values)
###################################################################################################
#
@ -292,6 +391,16 @@ void initialize_${configuration_name}(Manifest &manifest) {
Operation_${operation_name}>(
"${operation_name}"));
"""
self.configuration_direct_conv_instance = """
using Operation_${operation_name} = cutlass::conv::device::DirectConvolution<
${operation_name}>;
manifest.append(new cutlass::library::DirectConv2dOperation<
Operation_${operation_name}>(
"${operation_name}"));
"""
self.configuration_epilogue = """
@ -334,10 +443,16 @@ void initialize_${configuration_name}(Manifest &manifest) {
}))
for operation in self.operations:
self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
'configuration_name': self.configuration_name,
'operation_name': operation.procedural_name()
}))
if operation.group_mode == GroupMode.Depthwise:
self.configuration_file.write(SubstituteTemplate(self.configuration_direct_conv_instance, {
'configuration_name': self.configuration_name,
'operation_name': operation.procedural_name()
}))
else:
self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
'configuration_name': self.configuration_name,
'operation_name': operation.procedural_name()
}))
self.configuration_file.write(self.configuration_epilogue)
self.configuration_file.write(self.epilogue_template)

View File

@ -11,6 +11,8 @@ import argparse
from library import *
from manifest import *
from itertools import product
###################################################################################################
#
@ -49,6 +51,8 @@ def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8):
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
swizzling_functor = SwizzlingFunctor.Identity8):
# Use StreamK decomposition for basic GEMMs
# swizzling_functor = SwizzlingFunctor.StreamK):
if complex_transforms is None:
complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]
@ -373,11 +377,26 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
# Strided support for Analytic and Optimized Fprop
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
manifest.append(new_operation)
operations.append(new_operation)
new_operations = [
# None grouped kernel
Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_),
]
# Instance group conv kernel
if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC:
# SingleGroup kernel
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
# Analytic iterator supports MultipleGroup mode
if iterator_algorithm == IteratorAlgorithm.Analytic:
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup))
for new_operation in new_operations:
manifest.append(new_operation)
operations.append(new_operation)
#
# Conv2d Dgrad
@ -593,6 +612,62 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme
return operations
# Convolution for Depthwise 2d conv
def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \
  conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \
  epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
  """Instantiate depthwise 2-D convolution operations and register them with the manifest.

  Only the Fprop conv kind is currently emitted; Dgrad/Wgrad entries in
  `conv_kinds` are accepted but ignored by this implementation.
  Returns the list of created Conv2dOperation objects.
  """
  element_a, element_b, element_c, element_epilogue = data_type

  # Depthwise direct conv supports exactly these two iterator algorithms.
  iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized]

  # Without an explicit kernel filter, emit only the first (largest) tile
  # and the first (largest) alignment to keep the generated set small.
  if manifest.kernel_filter == '':
    tile_descriptions = tile_descriptions[:1]
    alignment_constraints = alignment_constraints[:1]

  operations = []

  for tile in tile_descriptions:
    for alignment in alignment_constraints:
      alignment_c = min(8, alignment)

      A = TensorDescription(element_a, layout[0], alignment)
      B = TensorDescription(element_b, layout[1], alignment)
      C = TensorDescription(element_c, layout[2], alignment_c)

      swizzling_functor_ = swizzling_functor

      if ConvKind.Fprop in conv_kinds:
        for iterator_algorithm in iterator_algorithms:
          stride_support = StrideSupport.Strided

          if iterator_algorithm == IteratorAlgorithm.FixedStrideDilation:
            # FixedStrideDilation needs concrete stride/dilation values;
            # [-1, -1] is the "any value" wildcard and cannot be fixed.
            if tile.stride == [-1, -1] or tile.dilation == [-1, -1]:
              continue
            stride_support = StrideSupport.Fixed
          elif iterator_algorithm == IteratorAlgorithm.Optimized:
            # Optimized handles only the wildcard stride/dilation case.
            if tile.stride != [-1, -1] or tile.dilation != [-1, -1]:
              continue

          new_operation = Conv2dOperation(
            ConvKind.Fprop,
            iterator_algorithm,
            tile.minimum_compute_capability,
            tile,
            A, B, C,
            element_epilogue,
            stride_support,
            epilogue_functor,
            swizzling_functor_,
            group_mode=GroupMode.Depthwise)

          manifest.append(new_operation)
          operations.append(new_operation)

  return operations
###################################################################################################
###################################################################################################
@ -748,10 +823,83 @@ def GenerateSM60_Simt(manifest, cuda_version):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
#
def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version):
  """Emit SM60 SIMT f16 depthwise 2-D convolution kernels."""
  math_instructions = [
    MathInstruction(
      [1, 1, 1],
      DataType.f16, DataType.f16, DataType.f16,
      OpcodeClass.Simt,
      MathOperation.multiply_add),
  ]

  min_cc = 60
  max_cc = 1024

  alignment_constraints = [8,]

  filter_3x3 = [3, 3]
  filter_5x5 = [5, 5]

  # [stride_h, stride_w]; [-1, -1] is the "all stride sizes" wildcard.
  strides = [[-1, -1], [1, 1], [2, 2]]
  # [dilation_h, dilation_w]; [-1, -1] is the "all dilation sizes" wildcard.
  dilations = [[-1, -1], [1, 1], [2, 2]]

  # Groups per thread block.
  g16 = 16
  g32 = 32
  g64 = 64

  # Output shape (N, P, Q) per thread block.
  npq_1x4x4 = [1, 4, 4]
  npq_1x8x8 = [1, 8, 8]
  npq_1x10x10 = [1, 10, 10]

  # (threadblock output shape, groups per CTA, stage count); the same menu
  # is instantiated for both filter sizes, always with warp shape [4, 1, 1].
  tile_configs = [
    (npq_1x8x8, g32, 3),
    (npq_1x8x8, g64, 3),
    (npq_1x8x8, g16, 3),
    (npq_1x10x10, g64, 2),
    (npq_1x4x4, g32, 4),
    (npq_1x4x4, g64, 4),
    (npq_1x4x4, g16, 4),
  ]

  tile_descriptions = []
  for math_inst in math_instructions:
    for stride, dilation in product(strides, dilations):
      for filter_shape in (filter_3x3, filter_5x5):
        for npq, groups, stages in tile_configs:
          tile_descriptions.append(
            Direct2dConvFixedStrideDilationTileDescription(
              npq + [groups], filter_shape, stages, stride, dilation,
              [4, 1, 1], math_inst, min_cc, max_cc))

    data_type = [
      math_inst.element_a,
      math_inst.element_b,
      math_inst.element_accumulator,
      math_inst.element_accumulator,
    ]

    conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)

    CreateDepthwiseConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
#
#
def GenerateSM60(manifest, cuda_version):
  """Emit all SM60 (Pascal) kernels into the manifest."""
  GenerateSM60_Simt(manifest, cuda_version)
  GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version)
###################################################################################################
###################################################################################################
@ -3813,6 +3961,627 @@ def GenerateSM80(manifest, cuda_version):
GenerateSM80_Simt_complex(manifest, cuda_version)
###################################################################################################
#
def GenerateSM90_TensorOp_1684(manifest, cuda_version):
  """Emit SM90 double-precision TensorOp GEMM kernels (16x8x4 instruction shape)."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  # All four row/column-major combinations for A and B; C is column-major.
  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc),
    TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc),
    TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc),
  ]

  # A, B, accumulator, epilogue are all f64.
  data_type = [DataType.f64] * 4

  CreateGemmOperator(manifest, layouts, tile_descriptions,
    data_type, alignment_constraints)
#
#
def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version):
  """Emit SM90 complex<double> TensorOp GEMM kernels (16x8x4 instruction shape)."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add_complex)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc),
    TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.cf64] * 4

  # All four conjugation combinations of A and B.
  complex_transforms = [
    (ComplexTransform.none, ComplexTransform.none),
    (ComplexTransform.conj, ComplexTransform.none),
    (ComplexTransform.none, ComplexTransform.conj),
    (ComplexTransform.conj, ComplexTransform.conj),
  ]

  CreateGemmOperator(manifest, layouts, tile_descriptions,
    data_type, alignment_constraints, complex_transforms)
#
#
def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version):
  """Emit SM90 Gaussian complex<double> TensorOp GEMM kernels (3-multiply complex)."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add_complex_gaussian)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.cf64] * 4

  # All four conjugation combinations of A and B.
  complex_transforms = [
    (ComplexTransform.none, ComplexTransform.none),
    (ComplexTransform.conj, ComplexTransform.none),
    (ComplexTransform.none, ComplexTransform.conj),
    (ComplexTransform.conj, ComplexTransform.conj),
  ]

  CreateGemmOperator(manifest, layouts, tile_descriptions,
    data_type, alignment_constraints, complex_transforms)
#
#
def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version):
  """Emit SM90 f64 TensorOp symmetric rank-k (SYRK) kernels."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc),
  ]

  # Rank-k takes three data types: A, C, and the accumulator/epilogue.
  data_type = [DataType.f64] * 3

  CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
    data_type, alignment_constraints, BlasMode.symmetric)
#
#
def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version):
  """Emit SM90 complex<double> TensorOp rank-k kernels (SYRK and HERK)."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add_complex)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  # Smaller tiles below are kept disabled, mirroring the GEMM menu.
  tile_descriptions = [
    TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.cf64] * 3

  # SYRK computation
  CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
    data_type, alignment_constraints, BlasMode.symmetric)

  # HERK computation
  CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions,
    data_type, alignment_constraints, BlasMode.hermitian)
#
#
def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version):
  """Emit SM90 Gaussian complex<double> TensorOp rank-k kernels (SYRK and HERK)."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  math_inst = \
    MathInstruction( \
      [16, 8, 4], \
      DataType.f64, DataType.f64, DataType.f64, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_complex_gaussian)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  # Smaller tiles below are kept disabled, mirroring the GEMM menu.
  tile_descriptions = [
    TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.cf64, DataType.cf64, DataType.cf64]

  # NOTE(review): removed an unused `complex_transforms` local that was never
  # passed anywhere -- the CreateRankKOperator calls below take no transform
  # argument (matching GenerateSM90_TensorOp_1684_rank_k_complex).

  # SYRK computation
  CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \
    data_type, alignment_constraints, BlasMode.symmetric)

  # HERK computation
  CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \
    data_type, alignment_constraints, BlasMode.hermitian)
#
#
def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version):
  """Emit SM90 f64 TensorOp triangular matrix-multiply (TRMM) kernels."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  side_modes = [
    SideMode.Left, SideMode.Right,
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  diag_types = [
    DiagType.NonUnit, DiagType.Unit,
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.f64] * 4

  CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions,
    data_type, alignment_constraints)
#
#
def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version):
  """Emit SM90 complex<double> TensorOp TRMM kernels."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  side_modes = [
    SideMode.Left, SideMode.Right,
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  diag_types = [
    DiagType.NonUnit, DiagType.Unit,
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add_complex)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.cf64] * 4

  # TRMM applies the transform to the triangular operand only.
  complex_transforms = [
    ComplexTransform.none, ComplexTransform.conj,
  ]

  CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions,
    data_type, alignment_constraints, complex_transforms)
#
#
def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version):
  """Emit SM90 Gaussian complex<double> TensorOp TRMM kernels."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  side_modes = [
    SideMode.Left, SideMode.Right,
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  diag_types = [
    DiagType.NonUnit, DiagType.Unit,
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add_complex_gaussian)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.cf64] * 4

  # TRMM applies the transform to the triangular operand only.
  complex_transforms = [
    ComplexTransform.none, ComplexTransform.conj,
  ]

  CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions,
    data_type, alignment_constraints, complex_transforms)
#
#
def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version):
  """Emit SM90 f64 TensorOp symmetric matrix-multiply (SYMM) kernels."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  side_modes = [
    SideMode.Left, SideMode.Right,
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.f64] * 4

  CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions,
    data_type, alignment_constraints, BlasMode.symmetric)
#
#
def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version):
  """Emit SM90 complex<double> TensorOp SYMM and HEMM kernels."""
  # SM90 f64 tensor-op kernels require CUDA Toolkit 11.8 or newer.
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  side_modes = [
    SideMode.Left, SideMode.Right,
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add_complex)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  # Smaller tiles below are kept disabled, mirroring the GEMM menu.
  tile_descriptions = [
    TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
    #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.cf64] * 4

  # SYMM computation
  CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions,
    data_type, alignment_constraints, BlasMode.symmetric)

  # HEMM computation
  CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions,
    data_type, alignment_constraints, BlasMode.hermitian)
#
#
def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version):
  """Register SM90 SYMM and HEMM kernels that use the Gaussian (3-multiply)
  complex multiply-add on the f64 TensorOp 16x8x4 MMA.

  Emits nothing when the CUDA toolkit is older than 11.8.
  """
  if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
    return

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  side_modes = [
    SideMode.Left, SideMode.Right,
  ]

  fill_modes = [
    FillMode.Lower, FillMode.Upper,
  ]

  math_inst = MathInstruction(
    [16, 8, 4],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.TensorOp,
    MathOperation.multiply_add_complex_gaussian)

  min_cc = 90
  max_cc = 1024

  alignment_constraints = [1,]

  tile_descriptions = [
    TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
    TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
  ]

  data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64]

  # NOTE(review): a `complex_transforms` list was previously assigned here but
  # never passed to CreateSymmOperator, so it has been removed as dead code.

  # SYMM computation
  CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \
    data_type, alignment_constraints, BlasMode.symmetric)

  # HEMM computation
  CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \
    data_type, alignment_constraints, BlasMode.hermitian)
#
###################################################################################################
#
def GenerateSM90(manifest, cuda_version):
  """Emit every SM90 kernel family into the manifest, in a fixed order."""
  sm90_generators = (
    GenerateSM90_TensorOp_1684,
    GenerateSM90_TensorOp_1684_complex,
    GenerateSM90_TensorOp_1684_complex_gaussian,
    GenerateSM90_TensorOp_1684_rank_k,
    GenerateSM90_TensorOp_1684_rank_k_complex,
    GenerateSM90_TensorOp_1684_rank_k_complex_gaussian,
    GenerateSM90_TensorOp_1684_trmm,
    GenerateSM90_TensorOp_1684_trmm_complex,
    GenerateSM90_TensorOp_1684_trmm_complex_gaussian,
    GenerateSM90_TensorOp_1684_symm,
    GenerateSM90_TensorOp_1684_symm_complex,
    GenerateSM90_TensorOp_1684_symm_complex_gaussian,
  )
  for generate in sm90_generators:
    generate(manifest, cuda_version)
###################################################################################################
if __name__ == "__main__":
@ -3842,6 +4611,8 @@ if __name__ == "__main__":
GenerateSM70(manifest, args.cuda_version)
GenerateSM75(manifest, args.cuda_version)
GenerateSM80(manifest, args.cuda_version)
GenerateSM90(manifest, args.cuda_version)
if 'library' in args.generator_target.split(','):
manifest.emit(GeneratorTarget.Library)

View File

@ -471,9 +471,11 @@ SharedMemPerCC = {
70: 96, # 96KB of SMEM
72: 96, # 96KB of SMEM
75: 64, # 64KB of SMEM
80: 160, # 164KB of SMEM - 4KB reserved for the driver
86: 100, # 100KB of SMEM
87: 160, # 164KB of SMEM - 4KB reserved for the driver
80: 163, # 163KB of SMEM - 1KB reserved for the driver
86: 99, # 99KB of SMEM - 1KB reserved for the driver
87: 163, # 163KB of SMEM - 1KB reserved for the driver
89: 99, # 99KB of SMEM - 1KB reserved for the driver
90: 227, # 227KB of SMEM - 1KB reserved for the driver
}
###################################################################################################
@ -561,7 +563,8 @@ class SwizzlingFunctor(enum.Enum):
StridedDgradIdentity1 = enum_auto()
StridedDgradIdentity4 = enum_auto()
StridedDgradHorizontal = enum_auto()
StreamK = enum_auto()
#
SwizzlingFunctorTag = {
SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
@ -572,6 +575,7 @@ SwizzlingFunctorTag = {
SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
}
#
@ -618,38 +622,65 @@ class IteratorAlgorithm(enum.Enum):
Optimized = enum_auto()
FixedChannels = enum_auto()
FewChannels = enum_auto()
FixedStrideDilation = enum_auto()
#
IteratorAlgorithmTag = {
IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels'
IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
}
IteratorAlgorithmNames = {
IteratorAlgorithm.Analytic: 'analytic',
IteratorAlgorithm.Optimized: 'optimized',
IteratorAlgorithm.FixedChannels: 'fixed_channels',
IteratorAlgorithm.FewChannels: 'few_channels'
IteratorAlgorithm.FewChannels: 'few_channels',
IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
}
#
class StrideSupport(enum.Enum):
  """Conv stride-support variants selectable for a kernel."""
  Strided = enum_auto()  # general (arbitrary) stride
  Unity = enum_auto()    # stride fixed to 1
  Fixed = enum_auto()    # stride fixed at compile time
#
# Maps each StrideSupport to the C++ enumerator emitted into kernel templates.
StrideSupportTag = {
  StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
  StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
  StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
}
# Maps each StrideSupport to the token used in procedural kernel names
# (Strided contributes nothing to the name).
StrideSupportNames = {
  StrideSupport.Strided: '',
  StrideSupport.Unity: 'unity_stride',
  StrideSupport.Fixed: 'fixed_stride'
}
#
class GroupMode(enum.Enum):
  """Grouping modes for (grouped/depthwise) convolution kernels."""
  NoneGroup = enum_auto()      # dense conv (G=1)
  SingleGroup = enum_auto()    # grouped convolution (single group per CTA)
  MultipleGroup = enum_auto()  # grouped convolution (multiple groups per CTA)
  Depthwise = enum_auto()      # depthwise convolution (C=K=G)
#
# Maps each GroupMode to the C++ enumerator emitted into kernel templates.
GroupModeTag = {
  GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
  GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
  GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
  GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
}
# Maps each GroupMode to the token used in procedural kernel names
# (NoneGroup contributes nothing to the name).
GroupModeNames = {
  GroupMode.NoneGroup: '',
  GroupMode.SingleGroup: 'single_group',
  GroupMode.MultipleGroup: 'multiple_group',
  GroupMode.Depthwise: 'depthwise',
}
###################################################################################################
@ -677,6 +708,39 @@ class TileDescription:
def procedural_name(self):
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
#
class Direct2dConvFixedStrideDilationTileDescription:
  """Tile description for direct 2-D conv kernels whose stride and dilation
  may be fixed at compile time; [-1, -1] marks a dimension as not fixed."""

  def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
    # The first threadblock dimension is the product of the first three
    # output-shape entries; the last maps the filter footprint (R*S).
    flattened_output = (threadblock_output_shape[0]
                        * threadblock_output_shape[1]
                        * threadblock_output_shape[2])
    filter_footprint = filter_shape[0] * filter_shape[1]
    self.threadblock_shape = [flattened_output,
                              threadblock_output_shape[3],
                              filter_footprint]
    self.threadblock_output_shape = threadblock_output_shape
    self.filter_shape = filter_shape
    self.stages = stages
    self.warp_count = warp_count
    self.stride = stride
    self.dilation = dilation
    self.math_instruction = math_instruction
    self.minimum_compute_capability = min_compute
    self.maximum_compute_capability = max_compute

  def procedural_name(self):
    """Build the kernel-name fragment encoding tile, output, stage count,
    filter shape and — when fully fixed — stride and dilation."""
    tb = self.threadblock_shape
    out = self.threadblock_output_shape
    name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (
        tb[0], tb[1], tb[2],
        out[0], out[1], out[2], out[3],
        self.stages,
        self.filter_shape[0], self.filter_shape[1])
    # Only a fully fixed stride/dilation pair is encoded into the name;
    # [-1, -1] in either means "not fixed" and is omitted.
    if self.stride != [-1, -1] and self.dilation != [-1, -1]:
      name += "_stride%dx%d_dilation%dx%d" % (
          self.stride[0], self.stride[1],
          self.dilation[0], self.dilation[1])
    return name
#
class TensorDescription:
def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):

View File

@ -85,18 +85,11 @@ You can run the PyCUTLASS on NGC PyTorch container.
```shell
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.09-py3
```
PyCUTLASS requires additional dependency Boost C++ library, which can be installed with
```bash
apt-get update
apt-get -y install libboost-all-dev
```
### Environment variables
PyCUTLASS requires two environment variables:
* `CUTLASS_PATH`: the root directory of CUTLASS
* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed
* `CUTLASS_PATH`: the root directory of CUTLASS. You can set this from the location at which you cloned CUTLASS via: `export CUTLASS_PATH=$(pwd)`.
* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed. If running in bash with `nvcc` installed under a CUDA toolkit, you can set this to the location of your `nvcc` installation via: `export CUDA_INSTALL_PATH=$(which nvcc | awk -F'/bin/nvcc' '{print $1}')`
After setting these two environment variables, PyCUTLASS can be installed with
```shell

View File

@ -38,6 +38,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
@ -104,16 +105,12 @@ public:
//
/// Argument structure
struct Arguments {
struct Arguments : UniversalArgumentsBase {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueVisitor::Arguments epilogue_visitor;
void const * ptr_A;
@ -124,7 +121,6 @@ public:
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
@ -145,8 +141,6 @@ public:
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
@ -174,12 +168,10 @@ public:
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue_visitor(epilogue_visitor),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
@ -212,12 +204,10 @@ public:
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue_visitor(epilogue_visitor),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
@ -248,11 +238,19 @@ public:
//
/// Parameters structure
struct Params {
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC> {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
@ -261,10 +259,6 @@ public:
typename EpilogueVisitor::Params epilogue_visitor;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void * ptr_A;
void * ptr_B;
void * ptr_C;
@ -273,7 +267,6 @@ public:
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
@ -285,47 +278,21 @@ public:
// Methods
//
CUTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_C(0),
batch_stride_D(0),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr),
semaphore(nullptr) { }
/// Default constructor
Params() = default;
CUTLASS_HOST_DEVICE
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
int device_sms,
int sm_occupancy
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
ParamsBase(args, device_sms, sm_occupancy),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
epilogue_visitor(args.epilogue_visitor),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
@ -333,11 +300,9 @@ public:
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
semaphore(static_cast<int *>(workspace)) {
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)) {
}
@ -358,7 +323,6 @@ public:
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;
epilogue_visitor = args.epilogue_visitor;
@ -466,12 +430,6 @@ public:
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

View File

@ -38,11 +38,19 @@
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
#include <boost/core/demangle.hpp>
#include <cxxabi.h>
#include <cuda_runtime.h>
namespace py = pybind11;
std::string demangle(const char* mangled_name) {
std::size_t len = 0;
int status = 0;
std::unique_ptr<char> ptr(
__cxxabiv1::__cxa_demangle(mangled_name, nullptr, &len, &status));
return ptr.get();
}
template<typename T>
void bind_identity_swizzle(py::module & m, std::string name) {
py::class_<T>(m, name.c_str(),
@ -80,7 +88,7 @@ void bind_identity_swizzle(py::module & m, std::string name) {
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return boost::core::demangle(typeid(T).name());
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
}
@ -101,7 +109,7 @@ void bind_swizzle(py::module & m, std::string name, std::string doc) {
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return boost::core::demangle(typeid(T).name());
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
}
@ -124,7 +132,7 @@ void bind_dgrad_swizzle(py::module & m, std::string name) {
}, py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return boost::core::demangle(typeid(T).name());
return demangle(typeid(T).name());
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
}

View File

@ -69,9 +69,12 @@ def get_gemm_arguments(epilogue_functor):
class _GemmArguments(ctypes.Structure):
_fields_ = [
# Arguments from UniversalArgumentsBase
("mode", ctypes.c_int),
("problem_size", GemmCoord_),
("batch_count", ctypes.c_int),
("batch_stride_D", ctypes.c_longlong),
# Remaining arguments
("epilogue", _EpilogueOutputOpParams),
("ptr_A", ctypes.c_void_p),
("ptr_B", ctypes.c_void_p),
@ -80,7 +83,6 @@ def get_gemm_arguments(epilogue_functor):
("batch_stride_A", ctypes.c_longlong),
("batch_stride_B", ctypes.c_longlong),
("batch_stride_C", ctypes.c_longlong),
("batch_stride_D", ctypes.c_longlong),
("stride_a", ctypes.c_longlong),
("stride_b", ctypes.c_longlong),
("stride_c", ctypes.c_longlong),

View File

@ -229,7 +229,7 @@ class GemmArguments(ArgumentBase):
elif operand in ["c", "d"]:
tensor_coord = problem_size.mn()
else:
raise ValueError("unknonw operand: " + operand)
raise ValueError("unknown operand: " + operand)
layout = tensor_layout.packed(tensor_coord)
@ -245,22 +245,27 @@ class GemmArguments(ArgumentBase):
)
if self.gemm_mode == cutlass.gemm.Mode.Array:
arguments = self.operation.argument_type(
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
# Arguments from UniversalArgumentsBase
self.gemm_mode, problem_size_, self.batch_count, 0,
# Remaining arguments
self.output_op,
int(self.ptr_A_array_buffer.ptr),
int(self.ptr_B_array_buffer.ptr),
int(self.ptr_C_array_buffer.ptr),
int(self.ptr_D_array_buffer.ptr),
0, 0, 0, 0,
0, 0, 0,
self.lda, self.ldb, self.ldc, self.ldd,
self.lda, self.ldb, self.ldc, self.ldd,
0, 0, 0
)
else:
arguments = self.operation.argument_type(
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
# Arguments from UniversalArgumentsBase
self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D,
# Remaining arguments
self.output_op,
int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
self.batched_stride_D,
self.lda, self.ldb, self.ldc, self.ldd,
self.lda, self.ldb, self.ldc, self.ldd,
0, 0, 0
@ -299,8 +304,7 @@ class GemmArguments(ArgumentBase):
arguments, grid_tiled_shape, gemm_k_size = self.arguments
res_arg = self.operation.rt_module.get_args(
ctypes.byref(arguments), ctypes.byref(grid_tiled_shape),
gemm_k_size, ctypes.c_void_p(int(device_workspace)))
ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace)))
host_workspace = bytearray(res_arg.contents)
device_workspace = None
@ -582,10 +586,15 @@ extern "C" {
}
// Get the params as byte array
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, \
cutlass::gemm::GemmCoord* grid_tiled_shape, int gemm_k_size, int* workspace){
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){
${operation_name}_base::Params* params;
params = new ${operation_name}_base::Params(*argument, *grid_tiled_shape, gemm_k_size, workspace);
params = new ${operation_name}_base::Params(*argument,
-1, // SM count. Only used for stream-K
-1 // Occupancy. Only used for stream-K
);
// Semaphore holds the pointer to the workspace in the Params struct
params->semaphore = workspace;
char *bytes = ((char*)(params));
char *output = new char[sizeof(${operation_name}_base::Params)];

View File

@ -116,13 +116,11 @@ DataTypeNames = {
DataTypeTag = {
cutlass.dtype.b1: "cutlass::uint1b_t",
cutlass.dtype.u2: "cutlass::uint2b_t",
cutlass.dtype.u4: "cutlass::uint4b_t",
cutlass.dtype.u8: "uint8_t",
cutlass.dtype.u16: "uint16_t",
cutlass.dtype.u32: "uint32_t",
cutlass.dtype.u64: "uint64_t",
cutlass.dtype.s2: "cutlass::int2b_t",
cutlass.dtype.s4: "cutlass::int4b_t",
cutlass.int8: "int8_t",
cutlass.dtype.s16: "int16_t",
@ -138,13 +136,11 @@ DataTypeTag = {
cutlass.dtype.cf32: "cutlass::complex<float>",
cutlass.dtype.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
cutlass.dtype.cf64: "cutlass::complex<double>",
cutlass.dtype.cu2: "cutlass::complex<cutlass::uint2b_t>",
cutlass.dtype.cu4: "cutlass::complex<cutlass::uint4b_t>",
cutlass.dtype.cu8: "cutlass::complex<cutlass::uint8_t>",
cutlass.dtype.cu16: "cutlass::complex<cutlass::uint16_t>",
cutlass.dtype.cu32: "cutlass::complex<cutlass::uint32_t>",
cutlass.dtype.cu64: "cutlass::complex<cutlass::uint64_t>",
cutlass.dtype.cs2: "cutlass::complex<cutlass::int2b_t>",
cutlass.dtype.cs4: "cutlass::complex<cutlass::int4b_t>",
cutlass.dtype.cs8: "cutlass::complex<cutlass::int8_t>",
cutlass.dtype.cs16: "cutlass::complex<cutlass::int16_t>",

View File

@ -1,4 +1,4 @@
pushd $CUTLASS_PATH/examples/40_cutlass_py/
pushd $CUTLASS_PATH/examples/40_cutlass_py/customizable
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1