CUTLASS 2.4 (Implicit GEMM convolution) (#147)

CUTLASS 2.4 (Implicit GEMM Convolution)

Co-authored-by: Manish Gupta <manigupta@nvidia.com>, Haicheng Wu <haichengw@nvidia.com>, Dustyn Blasig <dblasig@nvidia.com>, Andrew Kerr <akerr@nvidia.com>
This commit is contained in:
Manish Gupta
2020-11-19 21:25:25 -08:00
committed by GitHub
parent c2b80ad4e4
commit 6615010cd0
224 changed files with 43939 additions and 1061 deletions

View File

@ -11,7 +11,6 @@ import argparse
from library import *
from manifest import *
from gemm_operation import *
###################################################################################################
#
@ -118,10 +117,9 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t
gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray]
# by default, only generate the largest tile and largest alignment
# by default, planar complex gemm kernels are not generated
if manifest.args.kernels == '':
tile_descriptions = [tile_descriptions[0],]
alignment_constraints = [alignment_constraints[0],]
return
for gemm_kind in gemm_kinds:
for layout in layouts:
@ -141,6 +139,103 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t
return
###########################################################################################################
# ConvolutionOperator support variations
# ____________________________________________________________________
# ConvolutionalOperator | Analytic | Optimized
# ____________________________________________________________________
# | Fprop | (strided) | (strided)
# | Dgrad | (strided, unity*) | (unity)
# | Wgrad | (strided) | (strided)
# ____________________________________________________________________
#
# Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low
###########################################################################################################
# Convolution for 2D operations
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination):
element_a, element_b, element_c, element_epilogue = data_type
# one exceptional case
alignment_c = min(8, alignment)
# iterator algorithm (analytic and optimized)
iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
# by default, only generate the largest tile size
if manifest.args.kernels == '':
tile_descriptions = [tile_descriptions[0],]
operations = []
for tile in tile_descriptions:
for conv_kind in conv_kinds:
for iterator_algorithm in iterator_algorithms:
A = TensorDescription(element_a, layout[0], alignment)
B = TensorDescription(element_b, layout[1], alignment)
C = TensorDescription(element_c, layout[2], alignment_c)
# unity stride only for Optimized Dgrad
if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad):
new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor)
manifest.append(new_operation)
operations.append(new_operation)
# strided dgrad is not supported by Optimized Dgrad
if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad):
continue
# strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor)
manifest.append(new_operation)
operations.append(new_operation)
return operations
# Convolution for 3D operations
def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination):
element_a, element_b, element_c, element_epilogue = data_type
# one exceptional case
alignment_c = min(8, alignment)
# iterator algorithm (analytic and optimized)
iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
# by default, only generate the largest tile size
if manifest.args.kernels == '':
tile_descriptions = [tile_descriptions[0],]
operations = []
for tile in tile_descriptions:
for conv_kind in conv_kinds:
for iterator_algorithm in iterator_algorithms:
A = TensorDescription(element_a, layout, alignment)
B = TensorDescription(element_b, layout, alignment)
C = TensorDescription(element_c, layout, alignment_c)
# optimized conv3d iterator algorithm is only for Wgrad
if (iterator_algorithm == IteratorAlgorithm.Optimized) \
and ((conv_kind == ConvKind.Fprop) or (conv_kind == ConvKind.Dgrad)):
continue
# strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
new_operation = Conv3dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor)
manifest.append(new_operation)
operations.append(new_operation)
return operations
###################################################################################################
###################################################################################################
@ -191,11 +286,57 @@ def GenerateSM50_Simt(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
if math_inst.element_a == DataType.f32:
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
#
#
def GenerateSM50_Simt_complex(manifest, args):
layouts = [
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
]
math_instructions = [
MathInstruction( \
[1, 1, 1], \
DataType.f32, DataType.f32, DataType.f32, \
OpcodeClass.Simt, \
MathOperation.multiply_add_complex),
]
min_cc = 50
max_cc = 1024
alignment_constraints = [1,]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
DataType.cf32,
DataType.cf32,
DataType.cf32,
DataType.cf32,
]
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
#
#
def GenerateSM50(manifest, args):
GenerateSM50_Simt(manifest, args)
GenerateSM50_Simt_complex(manifest, args)
###################################################################################################
###################################################################################################
@ -362,6 +503,9 @@ def GenerateSM70_TensorOp_884(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -375,6 +519,8 @@ def GenerateSM70_TensorOp_884(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
#
def GenerateSM70_PlanarComplexTensorOp_884(manifest, args):
@ -504,50 +650,10 @@ def GenerateSM70_WmmaTensorOp_161616(manifest, args):
#
##################################################################################################
#
def GenerateSM70_Simt_complex(manifest, args):
math_instructions = [
MathInstruction( \
[1, 1, 1], \
DataType.f32, DataType.f32, DataType.f32, \
OpcodeClass.Simt, \
MathOperation.multiply_add_complex),
]
min_cc = 70
max_cc = 1024
alignment_constraints = [1,]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 32, 16], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([32, 64, 16], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([32, 32, 16], 2, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
DataType.cf32,
DataType.cf32,
DataType.cf32,
DataType.cf32
]
complex_transforms = [
(ComplexTransform.none, ComplexTransform.none),
(ComplexTransform.conj, ComplexTransform.none),
(ComplexTransform.none, ComplexTransform.conj),
(ComplexTransform.conj, ComplexTransform.conj)
]
#
def GenerateSM70(manifest, args):
GenerateSM70_TensorOp_884(manifest, args)
GenerateSM70_PlanarComplexTensorOp_884(manifest, args)
GenerateSM70_Simt_complex(manifest, args)
# To limit build size, WMMA GEMMs are disabled for now.
#
@ -607,6 +713,9 @@ def GenerateSM75_TensorOp_1688(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -620,6 +729,8 @@ def GenerateSM75_TensorOp_1688(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
#
#
@ -738,6 +849,10 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -753,6 +868,9 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
op.C.alignment = 16
@ -794,6 +912,8 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args):
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
@ -809,9 +929,13 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args):
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
# conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
#
# operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
# data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
op.C.alignment = 8
#
#
@ -862,6 +986,10 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -877,6 +1005,9 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
op.C.alignment = 8
@ -920,9 +1051,9 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):
TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
]
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
@ -938,9 +1069,13 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
# conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
#
# operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
# data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
op.C.alignment = 16
#
#
@ -1074,6 +1209,8 @@ def GenerateSM75_Simt_complex(manifest, args):
(ComplexTransform.conj, ComplexTransform.conj)
]
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
#
def GenerateSM75(manifest, args):
@ -1124,6 +1261,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [8, 4, 2]
@ -1137,10 +1275,10 @@ def GenerateSM80_TensorOp_16816(manifest, args):
TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
@ -1157,6 +1295,10 @@ def GenerateSM80_TensorOp_16816(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -1170,6 +1312,8 @@ def GenerateSM80_TensorOp_16816(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8)
#
#
@ -1205,22 +1349,23 @@ def GenerateSM80_SparseTensorOp_16832(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [8, 4, 2]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
]
@ -1348,6 +1493,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [16,]
@ -1361,10 +1507,10 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
@ -1382,6 +1528,13 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
op.C.alignment = 16
@ -1409,21 +1562,22 @@ def GenerateSM80_SparseTensorOp_16864_TN(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [16,]
tile_descriptions = [
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
]
@ -1489,10 +1643,14 @@ def GenerateSM80_TensorOp_16832_Interleaved(manifest, args):
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
# conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
#
# operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
# data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
op.C.alignment = 8
#
#
@ -1520,6 +1678,7 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [32,]
@ -1533,14 +1692,14 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
]
data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32]
@ -1582,20 +1741,21 @@ def GenerateSM80_SparseTensorOp_168128_TN(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [32,]
tile_descriptions = [
TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 128, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc),
]
@ -1655,9 +1815,7 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32]
@ -1666,7 +1824,12 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
# conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
#
# operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
# data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
op.C.alignment = 16
#
@ -1744,6 +1907,7 @@ def GenerateSM80_TensorOp_1688(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [4, 2, 1]
@ -1757,11 +1921,11 @@ def GenerateSM80_TensorOp_1688(manifest, args):
TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
@ -1787,6 +1951,10 @@ def GenerateSM80_TensorOp_1688(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 4)
#
#
@ -1822,6 +1990,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [4, 2, 1]
@ -1835,11 +2004,11 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
@ -1850,6 +2019,8 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
#
#
@ -1875,22 +2046,23 @@ def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [4, 2, 1]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
]
@ -1971,13 +2143,14 @@ def GenerateSM80_TensorOp_884(manifest, args):
min_cc = 80
max_cc = 1024
max_cc_smem_limited = 80
alignment_constraints = [1,]
tile_descriptions = [
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
@ -2090,7 +2263,7 @@ def GenerateSM80_TensorOp_884_complex_gaussian(manifest, args):
###################################################################################################
#
def GenerateSM80_Simt(manifest, args):
def GenerateSM80_Simt_f32(manifest, args):
layouts = [
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
@ -2136,8 +2309,55 @@ def GenerateSM80_Simt(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
#
#
def GenerateSM80_Simt_f64(manifest, args):
layouts = [
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
]
math_instructions = [
MathInstruction( \
[1, 1, 1], \
DataType.f64, DataType.f64, DataType.f64, \
OpcodeClass.Simt, \
MathOperation.multiply_add),
]
min_cc = 80
max_cc = 1024
alignment_constraints = [1,]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 128, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_accumulator,
]
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
#
##################################################################################################
#
def GenerateSM80_Simt_complex(manifest, args):
@ -2154,7 +2374,29 @@ def GenerateSM80_Simt_complex(manifest, args):
alignment_constraints = [1,]
data_type = [
DataType.cf32,
DataType.cf32,
DataType.cf32,
DataType.cf32
]
layouts = [
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
]
complex_transforms = [
(ComplexTransform.none, ComplexTransform.none),
(ComplexTransform.conj, ComplexTransform.none),
(ComplexTransform.none, ComplexTransform.conj),
(ComplexTransform.conj, ComplexTransform.conj)
]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc),
@ -2165,20 +2407,11 @@ def GenerateSM80_Simt_complex(manifest, args):
TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
DataType.cf32,
DataType.cf32,
DataType.cf32,
DataType.cf32
]
complex_transforms = [
(ComplexTransform.none, ComplexTransform.none),
(ComplexTransform.conj, ComplexTransform.none),
(ComplexTransform.none, ComplexTransform.conj),
(ComplexTransform.conj, ComplexTransform.conj)
]
CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
#
###################################################################################################
@ -2202,7 +2435,8 @@ def GenerateSM80(manifest, args):
GenerateSM80_SparseTensorOp_168128_TN(manifest, args)
GenerateSM80_TensorOp_16864_Interleaved(manifest, args)
GenerateSM80_TensorOp_168256(manifest, args)
GenerateSM80_Simt(manifest, args)
GenerateSM80_Simt_f32(manifest, args)
GenerateSM80_Simt_f64(manifest, args)
GenerateSM80_Simt_complex(manifest, args)
###################################################################################################