releaase 2.11 (#703)
This commit is contained in:
@ -17,7 +17,8 @@ from library import *
|
||||
class Conv2dOperation:
|
||||
#
|
||||
def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
|
||||
stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1):
|
||||
stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \
|
||||
group_mode = GroupMode.NoneGroup):
|
||||
|
||||
self.operation_kind = OperationKind.Conv2d
|
||||
self.arch = arch
|
||||
@ -31,6 +32,7 @@ class Conv2dOperation:
|
||||
self.iterator_algorithm = iterator_algorithm
|
||||
self.stride_support = stride_support
|
||||
self.swizzling_functor = swizzling_functor
|
||||
self.group_mode = group_mode
|
||||
#
|
||||
def is_complex(self):
|
||||
complex_operators = [
|
||||
@ -95,17 +97,18 @@ class Conv2dOperation:
|
||||
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
|
||||
threadblock = "%dx%d_%dx%d" % (
|
||||
self.tile_description.threadblock_shape[0],
|
||||
self.tile_description.threadblock_shape[1],
|
||||
self.tile_description.threadblock_shape[2],
|
||||
self.tile_description.stages
|
||||
)
|
||||
threadblock = self.tile_description.procedural_name()
|
||||
|
||||
# grouped conv
|
||||
if self.group_mode != GroupMode.NoneGroup:
|
||||
group_conv_name = f"{GroupModeNames[self.group_mode]}_"
|
||||
else:
|
||||
group_conv_name = ""
|
||||
|
||||
if self.stride_support == StrideSupport.Unity:
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}"
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
|
||||
else:
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
|
||||
|
||||
return SubstituteTemplate(
|
||||
configuration_name,
|
||||
@ -115,6 +118,7 @@ class Conv2dOperation:
|
||||
'threadblock': threadblock,
|
||||
'layout': self.layout_name(),
|
||||
'alignment': "%d" % self.A.alignment,
|
||||
'group_conv_name': group_conv_name
|
||||
}
|
||||
)
|
||||
|
||||
@ -162,7 +166,77 @@ class EmitConv2dInstance:
|
||||
${align_b}
|
||||
>::Kernel;
|
||||
"""
|
||||
self.template_group_conv = """
|
||||
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
||||
using ${operation_name}_base =
|
||||
typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
|
||||
${element_a},
|
||||
${layout_a},
|
||||
${element_b},
|
||||
${layout_b},
|
||||
${element_c},
|
||||
${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
||||
${stages},
|
||||
${math_operator},
|
||||
${group_mode},
|
||||
${iterator_algorithm},
|
||||
${stride_support},
|
||||
${align_a},
|
||||
${align_b}
|
||||
>::Kernel;
|
||||
"""
|
||||
self.template_depthwise_direct_conv = """
|
||||
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
||||
using ${operation_name}_base =
|
||||
typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
|
||||
${element_a},
|
||||
${layout_a},
|
||||
${element_b},
|
||||
${layout_b},
|
||||
${element_c},
|
||||
${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
|
||||
cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue},
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
|
||||
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
|
||||
1,
|
||||
${threadblock_output_shape_n},
|
||||
${threadblock_output_shape_p},
|
||||
${threadblock_output_shape_q}>,
|
||||
${stages},
|
||||
${math_operator},
|
||||
${iterator_algorithm},
|
||||
${stride_support},
|
||||
cutlass::MatrixShape<${stride_r}, ${stride_s}>,
|
||||
cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
|
||||
>::Kernel;
|
||||
"""
|
||||
|
||||
def emit(self, operation):
|
||||
|
||||
@ -206,7 +280,32 @@ class EmitConv2dInstance:
|
||||
'align_b': str(operation.B.alignment),
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.template, values)
|
||||
if operation.group_mode == GroupMode.NoneGroup:
|
||||
return SubstituteTemplate(self.template, values)
|
||||
|
||||
elif operation.group_mode == GroupMode.Depthwise:
|
||||
values['group_mode'] = GroupModeTag[operation.group_mode]
|
||||
# Setup other template params
|
||||
values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0])
|
||||
values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1])
|
||||
values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2])
|
||||
|
||||
values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3])
|
||||
|
||||
values['filter_shape_r'] = str(operation.tile_description.filter_shape[0])
|
||||
values['filter_shape_s'] = str(operation.tile_description.filter_shape[1])
|
||||
|
||||
values['stride_r'] = str(operation.tile_description.stride[0])
|
||||
values['stride_s'] = str(operation.tile_description.stride[1])
|
||||
|
||||
values['dilation_r'] = str(operation.tile_description.dilation[0])
|
||||
values['dilation_s'] = str(operation.tile_description.dilation[1])
|
||||
|
||||
return SubstituteTemplate(self.template_depthwise_direct_conv, values)
|
||||
|
||||
else:
|
||||
values['group_mode'] = GroupModeTag[operation.group_mode]
|
||||
return SubstituteTemplate(self.template_group_conv, values)
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
@ -292,6 +391,16 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
Operation_${operation_name}>(
|
||||
"${operation_name}"));
|
||||
|
||||
"""
|
||||
|
||||
self.configuration_direct_conv_instance = """
|
||||
using Operation_${operation_name} = cutlass::conv::device::DirectConvolution<
|
||||
${operation_name}>;
|
||||
|
||||
manifest.append(new cutlass::library::DirectConv2dOperation<
|
||||
Operation_${operation_name}>(
|
||||
"${operation_name}"));
|
||||
|
||||
"""
|
||||
|
||||
self.configuration_epilogue = """
|
||||
@ -334,10 +443,16 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
}))
|
||||
|
||||
for operation in self.operations:
|
||||
self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
|
||||
'configuration_name': self.configuration_name,
|
||||
'operation_name': operation.procedural_name()
|
||||
}))
|
||||
if operation.group_mode == GroupMode.Depthwise:
|
||||
self.configuration_file.write(SubstituteTemplate(self.configuration_direct_conv_instance, {
|
||||
'configuration_name': self.configuration_name,
|
||||
'operation_name': operation.procedural_name()
|
||||
}))
|
||||
else:
|
||||
self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
|
||||
'configuration_name': self.configuration_name,
|
||||
'operation_name': operation.procedural_name()
|
||||
}))
|
||||
|
||||
self.configuration_file.write(self.configuration_epilogue)
|
||||
self.configuration_file.write(self.epilogue_template)
|
||||
|
||||
@ -11,6 +11,8 @@ import argparse
|
||||
|
||||
from library import *
|
||||
from manifest import *
|
||||
from itertools import product
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
@ -49,6 +51,8 @@ def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8):
|
||||
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
|
||||
alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
|
||||
swizzling_functor = SwizzlingFunctor.Identity8):
|
||||
# Use StreamK decomposition for basic GEMMs
|
||||
# swizzling_functor = SwizzlingFunctor.StreamK):
|
||||
|
||||
if complex_transforms is None:
|
||||
complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]
|
||||
@ -373,11 +377,26 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
|
||||
|
||||
# Strided support for Analytic and Optimized Fprop
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
new_operations = [
|
||||
# None grouped kernel
|
||||
Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_),
|
||||
]
|
||||
|
||||
# Instance group conv kernel
|
||||
if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC:
|
||||
# SingleGroup kernel
|
||||
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
|
||||
|
||||
# Analytic iterator supports MultipleGroup mode
|
||||
if iterator_algorithm == IteratorAlgorithm.Analytic:
|
||||
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup))
|
||||
|
||||
for new_operation in new_operations:
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
#
|
||||
# Conv2d Dgrad
|
||||
@ -593,6 +612,62 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme
|
||||
|
||||
return operations
|
||||
|
||||
# Convolution for Depthwise 2d conv
|
||||
def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \
|
||||
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
|
||||
|
||||
element_a, element_b, element_c, element_epilogue = data_type
|
||||
|
||||
# iterator algorithm (FixedStrideDilation, Optimized)
|
||||
iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized]
|
||||
|
||||
# by default, only generate the largest tile size, largest alignment, and optimized iterator
|
||||
if manifest.kernel_filter == '':
|
||||
tile_descriptions = [tile_descriptions[0],]
|
||||
alignment_constraints = [alignment_constraints[0],]
|
||||
|
||||
operations = []
|
||||
|
||||
for tile in tile_descriptions:
|
||||
for alignment in alignment_constraints:
|
||||
|
||||
alignment_c = min(8, alignment)
|
||||
|
||||
A = TensorDescription(element_a, layout[0], alignment)
|
||||
B = TensorDescription(element_b, layout[1], alignment)
|
||||
C = TensorDescription(element_c, layout[2], alignment_c)
|
||||
|
||||
swizzling_functor_ = swizzling_functor
|
||||
|
||||
if ConvKind.Fprop in conv_kinds:
|
||||
|
||||
# Strided support for Optimized and FixedStridedDilation Depthwise Conv
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
stride_support = StrideSupport.Strided
|
||||
if iterator_algorithm == IteratorAlgorithm.FixedStrideDilation:
|
||||
if tile.stride == [-1, -1] or tile.dilation == [-1,-1]:
|
||||
continue
|
||||
stride_support = StrideSupport.Fixed
|
||||
|
||||
if iterator_algorithm == IteratorAlgorithm.Optimized:
|
||||
if tile.stride != [-1, -1] or tile.dilation != [-1,-1]:
|
||||
continue
|
||||
new_operation = Conv2dOperation(ConvKind.Fprop,
|
||||
iterator_algorithm,
|
||||
tile.minimum_compute_capability,
|
||||
tile,
|
||||
A, B, C,
|
||||
element_epilogue,
|
||||
stride_support,
|
||||
epilogue_functor,
|
||||
swizzling_functor_,
|
||||
group_mode=GroupMode.Depthwise)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
return operations
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
@ -748,10 +823,83 @@ def GenerateSM60_Simt(manifest, cuda_version):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
#
|
||||
def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version):
|
||||
|
||||
math_instructions = [
|
||||
MathInstruction( \
|
||||
[1, 1, 1], \
|
||||
DataType.f16, DataType.f16, DataType.f16, \
|
||||
OpcodeClass.Simt, \
|
||||
MathOperation.multiply_add),
|
||||
]
|
||||
|
||||
min_cc = 60
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [8,]
|
||||
|
||||
filter_3x3 = [3, 3]
|
||||
filter_5x5 = [5, 5]
|
||||
|
||||
# [stride_h, stride_w]
|
||||
# [-1, -1] means all stride size.
|
||||
strides = [[-1,-1], [1, 1], [2, 2]]
|
||||
# [dilation_h, dilation_w]
|
||||
# [-1, -1] means all dilation size.
|
||||
dilations = [[-1,-1], [1, 1], [2, 2]]
|
||||
|
||||
#groups per thread block
|
||||
g16 = 16
|
||||
g32 = 32
|
||||
g64 = 64
|
||||
|
||||
#output shape per thread block
|
||||
npq_1x4x4 = [1, 4, 4]
|
||||
npq_1x8x8 = [1, 8, 8]
|
||||
npq_1x10x10 = [1, 10, 10]
|
||||
|
||||
tile_descriptions = []
|
||||
for math_inst in math_instructions:
|
||||
for stride, dilation in product(strides, dilations):
|
||||
tile_descriptions.extend([
|
||||
# filter3x3 ThreadBlock_output, filter, stage, warp
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_3x3, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_3x3, 4, stride, dilation,[4, 1, 1], math_inst, min_cc, max_cc),
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
|
||||
# filter5x5 ThreadBlock_output, filter, stage, warp
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_5x5, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc),
|
||||
Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc)
|
||||
])
|
||||
|
||||
data_type = [
|
||||
math_inst.element_a,
|
||||
math_inst.element_b,
|
||||
math_inst.element_accumulator,
|
||||
math_inst.element_accumulator,
|
||||
]
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateDepthwiseConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM60(manifest, cuda_version):
|
||||
GenerateSM60_Simt(manifest, cuda_version)
|
||||
GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version)
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
@ -3813,6 +3961,627 @@ def GenerateSM80(manifest, cuda_version):
|
||||
GenerateSM80_Simt_complex(manifest, cuda_version)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64]
|
||||
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add_complex)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64]
|
||||
|
||||
complex_transforms = [
|
||||
(ComplexTransform.none, ComplexTransform.none),
|
||||
(ComplexTransform.conj, ComplexTransform.none),
|
||||
(ComplexTransform.none, ComplexTransform.conj),
|
||||
(ComplexTransform.conj, ComplexTransform.conj)
|
||||
]
|
||||
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints, complex_transforms)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add_complex_gaussian)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64]
|
||||
|
||||
complex_transforms = [
|
||||
(ComplexTransform.none, ComplexTransform.none),
|
||||
(ComplexTransform.conj, ComplexTransform.none),
|
||||
(ComplexTransform.none, ComplexTransform.conj),
|
||||
(ComplexTransform.conj, ComplexTransform.conj)
|
||||
]
|
||||
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints, complex_transforms)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.f64, DataType.f64, DataType.f64]
|
||||
|
||||
CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.symmetric)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add_complex)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.cf64, DataType.cf64, DataType.cf64]
|
||||
|
||||
# SYRK computation
|
||||
CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.symmetric)
|
||||
|
||||
# HERK computation
|
||||
CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.hermitian)
|
||||
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add_complex_gaussian)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.cf64, DataType.cf64, DataType.cf64]
|
||||
|
||||
complex_transforms = [ComplexTransform.none,]
|
||||
|
||||
# SYRK computation
|
||||
CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.symmetric)
|
||||
|
||||
# HERK computation
|
||||
CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.hermitian)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
side_modes = [
|
||||
SideMode.Left, SideMode.Right,
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
diag_types = [
|
||||
DiagType.NonUnit, DiagType.Unit,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64]
|
||||
|
||||
CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
side_modes = [
|
||||
SideMode.Left, SideMode.Right,
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
diag_types = [
|
||||
DiagType.NonUnit, DiagType.Unit,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add_complex)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64]
|
||||
|
||||
complex_transforms = [
|
||||
ComplexTransform.none, ComplexTransform.conj,
|
||||
]
|
||||
|
||||
CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \
|
||||
data_type, alignment_constraints, complex_transforms)
|
||||
#
|
||||
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
side_modes = [
|
||||
SideMode.Left, SideMode.Right,
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
diag_types = [
|
||||
DiagType.NonUnit, DiagType.Unit,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add_complex_gaussian)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64]
|
||||
|
||||
complex_transforms = [
|
||||
ComplexTransform.none, ComplexTransform.conj,
|
||||
]
|
||||
|
||||
CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \
|
||||
data_type, alignment_constraints, complex_transforms)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
side_modes = [
|
||||
SideMode.Left, SideMode.Right,
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64]
|
||||
|
||||
CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.symmetric)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
side_modes = [
|
||||
SideMode.Left, SideMode.Right,
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add_complex)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64]
|
||||
|
||||
# SYMM computation
|
||||
CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.symmetric)
|
||||
|
||||
# HEMM computation
|
||||
CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.hermitian)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version):
|
||||
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 11, 8):
|
||||
return
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
side_modes = [
|
||||
SideMode.Left, SideMode.Right,
|
||||
]
|
||||
|
||||
fill_modes = [
|
||||
FillMode.Lower, FillMode.Upper,
|
||||
]
|
||||
|
||||
math_inst = \
|
||||
MathInstruction( \
|
||||
[16, 8, 4], \
|
||||
DataType.f64, DataType.f64, DataType.f64, \
|
||||
OpcodeClass.TensorOp, \
|
||||
MathOperation.multiply_add_complex_gaussian)
|
||||
|
||||
min_cc = 90
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc),
|
||||
#TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64]
|
||||
|
||||
complex_transforms = [ComplexTransform.none,]
|
||||
|
||||
# SYMM computation
|
||||
CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.symmetric)
|
||||
|
||||
# HEMM computation
|
||||
CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \
|
||||
data_type, alignment_constraints, BlasMode.hermitian)
|
||||
#
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
def GenerateSM90(manifest, cuda_version):
|
||||
|
||||
GenerateSM90_TensorOp_1684(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_complex(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version)
|
||||
|
||||
GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_symm(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -3842,6 +4611,8 @@ if __name__ == "__main__":
|
||||
GenerateSM70(manifest, args.cuda_version)
|
||||
GenerateSM75(manifest, args.cuda_version)
|
||||
GenerateSM80(manifest, args.cuda_version)
|
||||
GenerateSM90(manifest, args.cuda_version)
|
||||
|
||||
if 'library' in args.generator_target.split(','):
|
||||
manifest.emit(GeneratorTarget.Library)
|
||||
|
||||
|
||||
@ -471,9 +471,11 @@ SharedMemPerCC = {
|
||||
70: 96, # 96KB of SMEM
|
||||
72: 96, # 96KB of SMEM
|
||||
75: 64, # 64KB of SMEM
|
||||
80: 160, # 164KB of SMEM - 4KB reserved for the driver
|
||||
86: 100, # 100KB of SMEM
|
||||
87: 160, # 164KB of SMEM - 4KB reserved for the driver
|
||||
80: 163, # 163KB of SMEM - 1KB reserved for the driver
|
||||
86: 99, # 99KB of SMEM - 1KB reserved for the driver
|
||||
87: 163, # 163KB of SMEM - 1KB reserved for the driver
|
||||
89: 99, # 99KB of SMEM - 1KB reserved for the driver
|
||||
90: 227, # 227KB of SMEM - 1KB reserved for the driver
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
@ -561,7 +563,8 @@ class SwizzlingFunctor(enum.Enum):
|
||||
StridedDgradIdentity1 = enum_auto()
|
||||
StridedDgradIdentity4 = enum_auto()
|
||||
StridedDgradHorizontal = enum_auto()
|
||||
|
||||
StreamK = enum_auto()
|
||||
|
||||
#
|
||||
SwizzlingFunctorTag = {
|
||||
SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
|
||||
@ -572,6 +575,7 @@ SwizzlingFunctorTag = {
|
||||
SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
|
||||
SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
|
||||
SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
|
||||
SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
|
||||
}
|
||||
|
||||
#
|
||||
@ -618,38 +622,65 @@ class IteratorAlgorithm(enum.Enum):
|
||||
Optimized = enum_auto()
|
||||
FixedChannels = enum_auto()
|
||||
FewChannels = enum_auto()
|
||||
FixedStrideDilation = enum_auto()
|
||||
|
||||
#
|
||||
IteratorAlgorithmTag = {
|
||||
IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
|
||||
IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
|
||||
IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
|
||||
IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels'
|
||||
IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
|
||||
IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
|
||||
}
|
||||
|
||||
IteratorAlgorithmNames = {
|
||||
IteratorAlgorithm.Analytic: 'analytic',
|
||||
IteratorAlgorithm.Optimized: 'optimized',
|
||||
IteratorAlgorithm.FixedChannels: 'fixed_channels',
|
||||
IteratorAlgorithm.FewChannels: 'few_channels'
|
||||
IteratorAlgorithm.FewChannels: 'few_channels',
|
||||
IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
|
||||
}
|
||||
|
||||
#
|
||||
class StrideSupport(enum.Enum):
|
||||
Strided = enum_auto()
|
||||
Unity = enum_auto()
|
||||
Fixed = enum_auto()
|
||||
|
||||
#
|
||||
StrideSupportTag = {
|
||||
StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
|
||||
StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
|
||||
StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
|
||||
}
|
||||
|
||||
StrideSupportNames = {
|
||||
StrideSupport.Strided: '',
|
||||
StrideSupport.Unity: 'unity_stride',
|
||||
StrideSupport.Fixed: 'fixed_stride'
|
||||
}
|
||||
|
||||
#
|
||||
class GroupMode(enum.Enum):
|
||||
NoneGroup = enum_auto() # dense conv (G=1)
|
||||
SingleGroup = enum_auto() # grouped convolution (single group per CTA)
|
||||
MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA)
|
||||
Depthwise = enum_auto() # Depthwise convolution ( C=K=G )
|
||||
|
||||
#
|
||||
GroupModeTag = {
|
||||
GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
|
||||
GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
|
||||
GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
|
||||
GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
|
||||
}
|
||||
|
||||
GroupModeNames = {
|
||||
GroupMode.NoneGroup: '',
|
||||
GroupMode.SingleGroup: 'single_group',
|
||||
GroupMode.MultipleGroup: 'multiple_group',
|
||||
GroupMode.Depthwise: 'depthwise',
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
@ -677,6 +708,39 @@ class TileDescription:
|
||||
def procedural_name(self):
|
||||
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
|
||||
|
||||
#
|
||||
class Direct2dConvFixedStrideDilationTileDescription:
|
||||
def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
|
||||
self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]]
|
||||
self.threadblock_output_shape = threadblock_output_shape
|
||||
self.filter_shape = filter_shape
|
||||
self.stages = stages
|
||||
self.warp_count = warp_count
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.math_instruction = math_instruction
|
||||
self.minimum_compute_capability = min_compute
|
||||
self.maximum_compute_capability = max_compute
|
||||
|
||||
def procedural_name(self):
|
||||
str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],
|
||||
self.threadblock_shape[1],
|
||||
self.threadblock_shape[2],
|
||||
self.threadblock_output_shape[0],
|
||||
self.threadblock_output_shape[1],
|
||||
self.threadblock_output_shape[2],
|
||||
self.threadblock_output_shape[3],
|
||||
self.stages,
|
||||
self.filter_shape[0],
|
||||
self.filter_shape[1])
|
||||
# Fixed Strided and dilation
|
||||
if self.stride != [-1, -1] and self.dilation != [-1, -1]:
|
||||
str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0],
|
||||
self.stride[1],
|
||||
self.dilation[0],
|
||||
self.dilation[1])
|
||||
return str_name
|
||||
|
||||
#
|
||||
class TensorDescription:
|
||||
def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
|
||||
|
||||
@ -85,18 +85,11 @@ You can run the PyCUTLASS on NGC PyTorch container.
|
||||
```shell
|
||||
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.09-py3
|
||||
```
|
||||
PyCUTLASS requires additional dependency Boost C++ library, which can be installed with
|
||||
```bash
|
||||
apt-get update
|
||||
apt-get -y install libboost-all-dev
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Environment variables
|
||||
PyCUTLASSS requires two environment variables:
|
||||
* `CUTLASS_PATH`: the root directory of CUTLASS
|
||||
* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed
|
||||
* `CUTLASS_PATH`: the root directory of CUTLASS. You can set this from the location at which you cloned CUTLASS via: `export CUTLASS_PATH=$(pwd)`.
|
||||
* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed. If running in bash with `nvcc` installed under a CUDA toolkit, you can set this to the location of your `nvcc` installation via: `export CUDA_INSTALL_PATH=$(which nvcc | awk -F'/bin/nvcc' '{print $1}')`
|
||||
|
||||
After setting these two environment variables, PyCUTLASS can be installed with
|
||||
```shell
|
||||
|
||||
@ -38,6 +38,7 @@
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/params_universal_base.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
@ -104,16 +105,12 @@ public:
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
struct Arguments : UniversalArgumentsBase {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
void const * ptr_A;
|
||||
@ -124,7 +121,6 @@ public:
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename LayoutA::Stride stride_a;
|
||||
typename LayoutB::Stride stride_b;
|
||||
@ -145,8 +141,6 @@ public:
|
||||
//
|
||||
|
||||
Arguments():
|
||||
mode(GemmUniversalMode::kGemm),
|
||||
batch_count(1),
|
||||
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
@ -174,12 +168,10 @@ public:
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
|
||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
@ -212,12 +204,10 @@ public:
|
||||
int const *ptr_gather_B_indices = nullptr,
|
||||
int const *ptr_scatter_D_indices = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
|
||||
epilogue_visitor(epilogue_visitor),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
|
||||
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
|
||||
ptr_scatter_D_indices(ptr_scatter_D_indices) {
|
||||
@ -248,11 +238,19 @@ public:
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
struct Params : UniversalParamsBase<
|
||||
ThreadblockSwizzle,
|
||||
ThreadblockShape,
|
||||
ElementA,
|
||||
ElementB,
|
||||
ElementC> {
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
using ParamsBase = UniversalParamsBase<
|
||||
ThreadblockSwizzle,
|
||||
ThreadblockShape,
|
||||
ElementA,
|
||||
ElementB,
|
||||
ElementC>;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
@ -261,10 +259,6 @@ public:
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void * ptr_A;
|
||||
void * ptr_B;
|
||||
void * ptr_C;
|
||||
@ -273,7 +267,6 @@ public:
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
int * ptr_gather_A_indices;
|
||||
int * ptr_gather_B_indices;
|
||||
@ -285,47 +278,21 @@ public:
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
swizzle_log_tile(0),
|
||||
params_A(0),
|
||||
params_B(0),
|
||||
params_C(0),
|
||||
params_D(0),
|
||||
batch_count(0),
|
||||
gemm_k_size(0),
|
||||
mode(cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
batch_stride_A(0),
|
||||
batch_stride_B(0),
|
||||
batch_stride_C(0),
|
||||
batch_stride_D(0),
|
||||
ptr_gather_A_indices(nullptr),
|
||||
ptr_gather_B_indices(nullptr),
|
||||
ptr_scatter_D_indices(nullptr),
|
||||
semaphore(nullptr) { }
|
||||
/// Default constructor
|
||||
Params() = default;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Arguments const &args,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
int gemm_k_size,
|
||||
void *workspace = nullptr
|
||||
int device_sms,
|
||||
int sm_occupancy
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
ParamsBase(args, device_sms, sm_occupancy),
|
||||
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
|
||||
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
|
||||
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
|
||||
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
|
||||
epilogue_visitor(args.epilogue_visitor),
|
||||
mode(args.mode),
|
||||
batch_count(args.batch_count),
|
||||
gemm_k_size(gemm_k_size),
|
||||
ptr_A(const_cast<void *>(args.ptr_A)),
|
||||
ptr_B(const_cast<void *>(args.ptr_B)),
|
||||
ptr_C(const_cast<void *>(args.ptr_C)),
|
||||
@ -333,11 +300,9 @@ public:
|
||||
batch_stride_A(args.batch_stride_A),
|
||||
batch_stride_B(args.batch_stride_B),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D),
|
||||
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
|
||||
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
|
||||
semaphore(static_cast<int *>(workspace)) {
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)) {
|
||||
|
||||
}
|
||||
|
||||
@ -358,7 +323,6 @@ public:
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
|
||||
epilogue_visitor = args.epilogue_visitor;
|
||||
|
||||
@ -466,12 +430,6 @@ public:
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const &args,
|
||||
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
@ -38,11 +38,19 @@
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include <boost/core/demangle.hpp>
|
||||
#include <cxxabi.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
std::string demangle(const char* mangled_name) {
|
||||
std::size_t len = 0;
|
||||
int status = 0;
|
||||
std::unique_ptr<char> ptr(
|
||||
__cxxabiv1::__cxa_demangle(mangled_name, nullptr, &len, &status));
|
||||
return ptr.get();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void bind_identity_swizzle(py::module & m, std::string name) {
|
||||
py::class_<T>(m, name.c_str(),
|
||||
@ -80,7 +88,7 @@ void bind_identity_swizzle(py::module & m, std::string name) {
|
||||
py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
return demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
@ -101,7 +109,7 @@ void bind_swizzle(py::module & m, std::string name, std::string doc) {
|
||||
py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
return demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
@ -124,7 +132,7 @@ void bind_dgrad_swizzle(py::module & m, std::string name) {
|
||||
}, py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return boost::core::demangle(typeid(T).name());
|
||||
return demangle(typeid(T).name());
|
||||
}, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc");
|
||||
}
|
||||
|
||||
|
||||
@ -69,9 +69,12 @@ def get_gemm_arguments(epilogue_functor):
|
||||
|
||||
class _GemmArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
# Arguments from UniversalArgumentsBase
|
||||
("mode", ctypes.c_int),
|
||||
("problem_size", GemmCoord_),
|
||||
("batch_count", ctypes.c_int),
|
||||
("batch_stride_D", ctypes.c_longlong),
|
||||
# Remaining arguments
|
||||
("epilogue", _EpilogueOutputOpParams),
|
||||
("ptr_A", ctypes.c_void_p),
|
||||
("ptr_B", ctypes.c_void_p),
|
||||
@ -80,7 +83,6 @@ def get_gemm_arguments(epilogue_functor):
|
||||
("batch_stride_A", ctypes.c_longlong),
|
||||
("batch_stride_B", ctypes.c_longlong),
|
||||
("batch_stride_C", ctypes.c_longlong),
|
||||
("batch_stride_D", ctypes.c_longlong),
|
||||
("stride_a", ctypes.c_longlong),
|
||||
("stride_b", ctypes.c_longlong),
|
||||
("stride_c", ctypes.c_longlong),
|
||||
|
||||
@ -229,7 +229,7 @@ class GemmArguments(ArgumentBase):
|
||||
elif operand in ["c", "d"]:
|
||||
tensor_coord = problem_size.mn()
|
||||
else:
|
||||
raise ValueError("unknonw operand: " + operand)
|
||||
raise ValueError("unknown operand: " + operand)
|
||||
|
||||
layout = tensor_layout.packed(tensor_coord)
|
||||
|
||||
@ -245,22 +245,27 @@ class GemmArguments(ArgumentBase):
|
||||
)
|
||||
if self.gemm_mode == cutlass.gemm.Mode.Array:
|
||||
arguments = self.operation.argument_type(
|
||||
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
|
||||
# Arguments from UniversalArgumentsBase
|
||||
self.gemm_mode, problem_size_, self.batch_count, 0,
|
||||
# Remaining arguments
|
||||
self.output_op,
|
||||
int(self.ptr_A_array_buffer.ptr),
|
||||
int(self.ptr_B_array_buffer.ptr),
|
||||
int(self.ptr_C_array_buffer.ptr),
|
||||
int(self.ptr_D_array_buffer.ptr),
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
0, 0, 0
|
||||
)
|
||||
else:
|
||||
arguments = self.operation.argument_type(
|
||||
self.gemm_mode, problem_size_, self.batch_count, self.output_op,
|
||||
# Arguments from UniversalArgumentsBase
|
||||
self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D,
|
||||
# Remaining arguments
|
||||
self.output_op,
|
||||
int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D),
|
||||
self.batched_stride_A, self.batched_stride_B, self.batched_stride_C,
|
||||
self.batched_stride_D,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
self.lda, self.ldb, self.ldc, self.ldd,
|
||||
0, 0, 0
|
||||
@ -299,8 +304,7 @@ class GemmArguments(ArgumentBase):
|
||||
|
||||
arguments, grid_tiled_shape, gemm_k_size = self.arguments
|
||||
res_arg = self.operation.rt_module.get_args(
|
||||
ctypes.byref(arguments), ctypes.byref(grid_tiled_shape),
|
||||
gemm_k_size, ctypes.c_void_p(int(device_workspace)))
|
||||
ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace)))
|
||||
host_workspace = bytearray(res_arg.contents)
|
||||
|
||||
device_workspace = None
|
||||
@ -582,10 +586,15 @@ extern "C" {
|
||||
}
|
||||
|
||||
// Get the params as byte array
|
||||
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, \
|
||||
cutlass::gemm::GemmCoord* grid_tiled_shape, int gemm_k_size, int* workspace){
|
||||
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){
|
||||
${operation_name}_base::Params* params;
|
||||
params = new ${operation_name}_base::Params(*argument, *grid_tiled_shape, gemm_k_size, workspace);
|
||||
params = new ${operation_name}_base::Params(*argument,
|
||||
-1, // SM count. Only used for stream-K
|
||||
-1 // Occupancy. Only used for stream-K
|
||||
);
|
||||
|
||||
// Semaphore holds the pointer to the workspace in the Params struct
|
||||
params->semaphore = workspace;
|
||||
|
||||
char *bytes = ((char*)(params));
|
||||
char *output = new char[sizeof(${operation_name}_base::Params)];
|
||||
|
||||
@ -116,13 +116,11 @@ DataTypeNames = {
|
||||
|
||||
DataTypeTag = {
|
||||
cutlass.dtype.b1: "cutlass::uint1b_t",
|
||||
cutlass.dtype.u2: "cutlass::uint2b_t",
|
||||
cutlass.dtype.u4: "cutlass::uint4b_t",
|
||||
cutlass.dtype.u8: "uint8_t",
|
||||
cutlass.dtype.u16: "uint16_t",
|
||||
cutlass.dtype.u32: "uint32_t",
|
||||
cutlass.dtype.u64: "uint64_t",
|
||||
cutlass.dtype.s2: "cutlass::int2b_t",
|
||||
cutlass.dtype.s4: "cutlass::int4b_t",
|
||||
cutlass.int8: "int8_t",
|
||||
cutlass.dtype.s16: "int16_t",
|
||||
@ -138,13 +136,11 @@ DataTypeTag = {
|
||||
cutlass.dtype.cf32: "cutlass::complex<float>",
|
||||
cutlass.dtype.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
|
||||
cutlass.dtype.cf64: "cutlass::complex<double>",
|
||||
cutlass.dtype.cu2: "cutlass::complex<cutlass::uint2b_t>",
|
||||
cutlass.dtype.cu4: "cutlass::complex<cutlass::uint4b_t>",
|
||||
cutlass.dtype.cu8: "cutlass::complex<cutlass::uint8_t>",
|
||||
cutlass.dtype.cu16: "cutlass::complex<cutlass::uint16_t>",
|
||||
cutlass.dtype.cu32: "cutlass::complex<cutlass::uint32_t>",
|
||||
cutlass.dtype.cu64: "cutlass::complex<cutlass::uint64_t>",
|
||||
cutlass.dtype.cs2: "cutlass::complex<cutlass::int2b_t>",
|
||||
cutlass.dtype.cs4: "cutlass::complex<cutlass::int4b_t>",
|
||||
cutlass.dtype.cs8: "cutlass::complex<cutlass::int8_t>",
|
||||
cutlass.dtype.cs16: "cutlass::complex<cutlass::int16_t>",
|
||||
|
||||
2
tools/library/scripts/pycutlass/test/example/run_all_example.sh
Normal file → Executable file
2
tools/library/scripts/pycutlass/test/example/run_all_example.sh
Normal file → Executable file
@ -1,4 +1,4 @@
|
||||
pushd $CUTLASS_PATH/examples/40_cutlass_py/
|
||||
pushd $CUTLASS_PATH/examples/40_cutlass_py/customizable
|
||||
|
||||
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user