CUTLASS 2.4 (Implicit GEMM convolution) (#147)
CUTLASS 2.4 (Implicit GEMM Convolution) Co-authored-by: Manish Gupta <manigupta@nvidia.com>, Haicheng Wu <haichengw@nvidia.com>, Dustyn Blasig <dblasig@nvidia.com>, Andrew Kerr <akerr@nvidia.com>
This commit is contained in:
@ -11,7 +11,6 @@ import argparse
|
||||
|
||||
from library import *
|
||||
from manifest import *
|
||||
from gemm_operation import *
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
@ -118,10 +117,9 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t
|
||||
|
||||
gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray]
|
||||
|
||||
# by default, only generate the largest tile and largest alignment
|
||||
# by default, planar complex gemm kernels are not generated
|
||||
if manifest.args.kernels == '':
|
||||
tile_descriptions = [tile_descriptions[0],]
|
||||
alignment_constraints = [alignment_constraints[0],]
|
||||
return
|
||||
|
||||
for gemm_kind in gemm_kinds:
|
||||
for layout in layouts:
|
||||
@ -141,6 +139,103 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t
|
||||
return
|
||||
|
||||
###########################################################################################################
|
||||
# ConvolutionOperator support variations
|
||||
# ____________________________________________________________________
|
||||
# ConvolutionalOperator | Analytic | Optimized
|
||||
# ____________________________________________________________________
|
||||
# | Fprop | (strided) | (strided)
|
||||
# | Dgrad | (strided, unity*) | (unity)
|
||||
# | Wgrad | (strided) | (strided)
|
||||
# ____________________________________________________________________
|
||||
#
|
||||
# Note: Operators marked (*) are supported but not generated, to keep the instantiated kernel count low
|
||||
###########################################################################################################
|
||||
# Convolution for 2D operations
|
||||
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment,
    conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination):
    """Append 2D convolution operations to the manifest and return them.

    One Conv2dOperation is created for every combination of tile description,
    convolution kind and iterator algorithm (Analytic, Optimized).

    Args:
        manifest: object providing `args.kernels` and `append(operation)`.
        layout: indexable with three entries, used for operands A, B and C.
        tile_descriptions: sequence of tile descriptions; when
            `manifest.args.kernels` is the empty string only the first entry
            is used.
        data_type: four-element sequence unpacked as
            (element_a, element_b, element_c, element_epilogue).
        alignment: alignment applied to A and B; C uses `min(8, alignment)`.
        conv_kinds: convolution kinds to generate.
        epilogue_functor: epilogue functor forwarded to each operation.

    Returns:
        List of the operations that were appended to the manifest.
    """
    element_a, element_b, element_c, element_epilogue = data_type

    # The C operand never uses an alignment larger than 8.
    alignment_c = min(8, alignment)

    algorithms = (IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized)

    # Without an explicit kernel filter, restrict generation to the first
    # (largest) tile description.
    if manifest.args.kernels == '':
        tile_descriptions = [tile_descriptions[0],]

    generated = []

    for tile in tile_descriptions:
        for conv_kind in conv_kinds:
            for iterator_algorithm in algorithms:
                # Fresh tensor descriptions per operation: callers adjust
                # `op.C.alignment` on individual operations afterwards.
                tensor_a = TensorDescription(element_a, layout[0], alignment)
                tensor_b = TensorDescription(element_b, layout[1], alignment)
                tensor_c = TensorDescription(element_c, layout[2], alignment_c)

                # Optimized Dgrad is generated with unity stride only; every
                # other (kind, algorithm) pair is generated with strided support.
                is_optimized_dgrad = (iterator_algorithm == IteratorAlgorithm.Optimized) \
                    and (conv_kind == ConvKind.Dgrad)
                if is_optimized_dgrad:
                    stride_support = StrideSupport.Unity
                else:
                    stride_support = StrideSupport.Strided

                operation = Conv2dOperation(
                    conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,
                    tensor_a, tensor_b, tensor_c, element_epilogue, stride_support, epilogue_functor)

                manifest.append(operation)
                generated.append(operation)

    return generated
|
||||
|
||||
# Convolution for 3D operations
|
||||
def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment,
    conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination):
    """Append 3D convolution operations to the manifest and return them.

    Every generated operation uses StrideSupport.Strided. The Analytic
    iterator algorithm is generated for each requested convolution kind; the
    Optimized iterator algorithm is skipped for Fprop and Dgrad.

    Args:
        manifest: object providing `args.kernels` and `append(operation)`.
        layout: single layout value shared by operands A, B and C.
        tile_descriptions: sequence of tile descriptions; when
            `manifest.args.kernels` is the empty string only the first entry
            is used.
        data_type: four-element sequence unpacked as
            (element_a, element_b, element_c, element_epilogue).
        alignment: alignment applied to A and B; C uses `min(8, alignment)`.
        conv_kinds: convolution kinds to generate.
        epilogue_functor: epilogue functor forwarded to each operation.

    Returns:
        List of the operations that were appended to the manifest.
    """
    element_a, element_b, element_c, element_epilogue = data_type

    # The C operand never uses an alignment larger than 8.
    alignment_c = min(8, alignment)

    algorithms = (IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized)

    # Without an explicit kernel filter, restrict generation to the first
    # (largest) tile description.
    if manifest.args.kernels == '':
        tile_descriptions = [tile_descriptions[0],]

    generated = []

    for tile in tile_descriptions:
        for conv_kind in conv_kinds:
            for iterator_algorithm in algorithms:
                tensor_a = TensorDescription(element_a, layout, alignment)
                tensor_b = TensorDescription(element_b, layout, alignment)
                tensor_c = TensorDescription(element_c, layout, alignment_c)

                # The Optimized iterator is not generated for Fprop or Dgrad.
                unsupported = (iterator_algorithm == IteratorAlgorithm.Optimized) \
                    and (conv_kind in (ConvKind.Fprop, ConvKind.Dgrad))

                if not unsupported:
                    operation = Conv3dOperation(
                        conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,
                        tensor_a, tensor_b, tensor_c, element_epilogue,
                        StrideSupport.Strided, epilogue_functor)

                    manifest.append(operation)
                    generated.append(operation)

    return generated
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
|
||||
@ -191,11 +286,57 @@ def GenerateSM50_Simt(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
|
||||
if math_inst.element_a == DataType.f32:
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM50_Simt_complex(manifest, args):
    """Emit SIMT complex-valued (cf32) GEMM and Conv2d operations with min_cc 50.

    For each math instruction, one 128x128x8 tile description is built and
    passed to CreateGemmOperator (four A/B layout combinations, column-major
    C) and to CreateConv2dOperator (NHWC for all three operands, alignment 1).

    Args:
        manifest: manifest forwarded to the operator-creation helpers.
        args: not read by this function.
    """
    col_major = LayoutType.ColumnMajor
    row_major = LayoutType.RowMajor

    # All A/B layout pairs; the C operand is always column-major.
    layouts = [
        (layout_a, layout_b, col_major)
        for layout_a in (col_major, row_major)
        for layout_b in (col_major, row_major)
    ]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add_complex),
    ]

    min_cc = 50
    max_cc = 1024

    alignment_constraints = [1,]

    nhwc = LayoutType.TensorNHWC
    conv_layout = (nhwc, nhwc, nhwc)

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
        ]

        # A, B, C and epilogue element types are all complex f32.
        data_type = [DataType.cf32 for _ in range(4)]

        CreateGemmOperator(manifest, layouts, tile_descriptions,
            data_type, alignment_constraints)

        CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
#
|
||||
|
||||
#
|
||||
def GenerateSM50(manifest, args):
    """Run the SM50 generators (SIMT, then SIMT complex) against the manifest."""
    for generate in (GenerateSM50_Simt, GenerateSM50_Simt_complex):
        generate(manifest, args)
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
@ -362,6 +503,9 @@ def GenerateSM70_TensorOp_884(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
|
||||
@ -375,6 +519,8 @@ def GenerateSM70_TensorOp_884(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
|
||||
#
|
||||
def GenerateSM70_PlanarComplexTensorOp_884(manifest, args):
|
||||
|
||||
@ -504,50 +650,10 @@ def GenerateSM70_WmmaTensorOp_161616(manifest, args):
|
||||
#
|
||||
##################################################################################################
|
||||
#
|
||||
def GenerateSM70_Simt_complex(manifest, args):
    """Build SIMT complex (cf32) tile descriptions for min_cc 70.

    NOTE(review): the visible body only constructs the math instruction, tile
    descriptions, data types and complex transforms; it contains no
    Create*Operator call, so nothing is appended to `manifest` here.

    Args:
        manifest: not read by this function.
        args: not read by this function.
    """
    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32, DataType.f32, DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add_complex),
    ]

    min_cc = 70
    max_cc = 1024

    alignment_constraints = [1,]

    # (threadblock shape, warp count) pairs; every tile uses 2 stages.
    tile_shapes = (
        ((128, 128, 8), (4, 2, 1)),
        ((64, 128, 32), (4, 2, 1)),
        ((128, 64, 32), (4, 2, 1)),
        ((64, 64, 32), (2, 2, 1)),
        ((64, 32, 16), (2, 2, 1)),
        ((32, 64, 16), (2, 2, 1)),
        ((32, 32, 16), (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(list(threadblock), 2, list(warps), math_inst, min_cc, max_cc)
            for threadblock, warps in tile_shapes
        ]

        # A, B, C and epilogue element types are all complex f32.
        data_type = [DataType.cf32 for _ in range(4)]

        complex_transforms = [
            (ComplexTransform.none, ComplexTransform.none),
            (ComplexTransform.conj, ComplexTransform.none),
            (ComplexTransform.none, ComplexTransform.conj),
            (ComplexTransform.conj, ComplexTransform.conj),
        ]
|
||||
|
||||
#
|
||||
|
||||
def GenerateSM70(manifest, args):
|
||||
GenerateSM70_TensorOp_884(manifest, args)
|
||||
GenerateSM70_PlanarComplexTensorOp_884(manifest, args)
|
||||
GenerateSM70_Simt_complex(manifest, args)
|
||||
|
||||
# To limit build size, WMMA GEMMs are disabled for now.
|
||||
#
|
||||
@ -607,6 +713,9 @@ def GenerateSM75_TensorOp_1688(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
|
||||
@ -620,6 +729,8 @@ def GenerateSM75_TensorOp_1688(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
|
||||
#
|
||||
|
||||
#
|
||||
@ -738,6 +849,10 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
|
||||
@ -753,6 +868,9 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
|
||||
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
op.C.alignment = 16
|
||||
@ -794,6 +912,8 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args):
|
||||
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -809,9 +929,13 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args):
|
||||
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
# conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
|
||||
#
|
||||
# operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
# data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 8
|
||||
|
||||
#
|
||||
|
||||
#
|
||||
@ -862,6 +986,10 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
|
||||
@ -877,6 +1005,9 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
|
||||
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
op.C.alignment = 8
|
||||
@ -920,9 +1051,9 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):
|
||||
TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
@ -938,9 +1069,13 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):
|
||||
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
# conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
|
||||
#
|
||||
# operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
# data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 16
|
||||
|
||||
#
|
||||
|
||||
#
|
||||
@ -1074,6 +1209,8 @@ def GenerateSM75_Simt_complex(manifest, args):
|
||||
(ComplexTransform.conj, ComplexTransform.conj)
|
||||
]
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
#
|
||||
|
||||
def GenerateSM75(manifest, args):
|
||||
@ -1124,6 +1261,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [8, 4, 2]
|
||||
|
||||
@ -1137,10 +1275,10 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -1157,6 +1295,10 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
|
||||
@ -1170,6 +1312,8 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8)
|
||||
#
|
||||
|
||||
#
|
||||
@ -1205,22 +1349,23 @@ def GenerateSM80_SparseTensorOp_16832(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [8, 4, 2]
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = [
|
||||
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
@ -1348,6 +1493,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [16,]
|
||||
|
||||
@ -1361,10 +1507,10 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
|
||||
TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -1382,6 +1528,13 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
|
||||
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
op.C.alignment = 16
|
||||
@ -1409,21 +1562,22 @@ def GenerateSM80_SparseTensorOp_16864_TN(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [16,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
@ -1489,10 +1643,14 @@ def GenerateSM80_TensorOp_16832_Interleaved(manifest, args):
|
||||
|
||||
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
|
||||
# conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
|
||||
#
|
||||
# operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
# data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 8
|
||||
|
||||
#
|
||||
|
||||
#
|
||||
@ -1520,6 +1678,7 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [32,]
|
||||
|
||||
@ -1533,14 +1692,14 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
|
||||
TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
]
|
||||
|
||||
data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32]
|
||||
@ -1582,20 +1741,21 @@ def GenerateSM80_SparseTensorOp_168128_TN(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [32,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 128, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
@ -1655,9 +1815,7 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):
|
||||
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
|
||||
data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32]
|
||||
@ -1666,7 +1824,12 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):
|
||||
|
||||
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
|
||||
# conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
|
||||
#
|
||||
# operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
# data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 16
|
||||
#
|
||||
@ -1744,6 +1907,7 @@ def GenerateSM80_TensorOp_1688(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [4, 2, 1]
|
||||
|
||||
@ -1757,11 +1921,11 @@ def GenerateSM80_TensorOp_1688(manifest, args):
|
||||
TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -1787,6 +1951,10 @@ def GenerateSM80_TensorOp_1688(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 4)
|
||||
#
|
||||
|
||||
#
|
||||
@ -1822,6 +1990,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [4, 2, 1]
|
||||
|
||||
@ -1835,11 +2004,11 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
|
||||
TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -1850,6 +2019,8 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
|
||||
#
|
||||
|
||||
#
|
||||
@ -1875,22 +2046,23 @@ def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [4, 2, 1]
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = [
|
||||
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
@ -1971,13 +2143,14 @@ def GenerateSM80_TensorOp_884(manifest, args):
|
||||
|
||||
min_cc = 80
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
|
||||
TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -2090,7 +2263,7 @@ def GenerateSM80_TensorOp_884_complex_gaussian(manifest, args):
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
def GenerateSM80_Simt(manifest, args):
|
||||
def GenerateSM80_Simt_f32(manifest, args):
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
@ -2136,8 +2309,55 @@ def GenerateSM80_Simt(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
#
|
||||
|
||||
|
||||
#
|
||||
def GenerateSM80_Simt_f64(manifest, args):
  """Emit the SM80 SIMT double-precision GEMM operations into `manifest`.

  One f64 multiply-add SIMT math instruction is paired with a fixed set of
  threadblock tiles, and the combination is handed to `CreateGemmOperator`
  for every A/B layout pairing (C is always column-major).

  `args` is accepted for signature parity with the other generators; this
  function does not read it.
  """
  # Compute-capability window attached to every tile description.
  min_cc = 80
  max_cc = 1024

  alignment_constraints = [1,]

  # A and B each take both majors; the nested order yields
  # (Col, Col), (Col, Row), (Row, Col), (Row, Row).
  operand_majors = (LayoutType.ColumnMajor, LayoutType.RowMajor)
  layouts = [
    (major_a, major_b, LayoutType.ColumnMajor)
    for major_a in operand_majors
    for major_b in operand_majors
  ]

  f64_multiply_add = MathInstruction(
    [1, 1, 1],
    DataType.f64, DataType.f64, DataType.f64,
    OpcodeClass.Simt,
    MathOperation.multiply_add)
  math_instructions = [f64_multiply_add,]

  # (threadblock shape, stages, warp count) for each generated tile,
  # listed in the order they are passed to CreateGemmOperator.
  tile_parameters = (
    ((128, 128, 8), 3, (4, 2, 1)),
    ((128,  64, 8), 4, (2, 2, 1)),
    (( 64, 128, 8), 4, (2, 2, 1)),
    (( 64,  64, 8), 5, (2, 1, 1)),
    ((128,  32, 8), 5, (2, 1, 1)),
    (( 32, 128, 8), 5, (1, 2, 1)),
  )

  for math_inst in math_instructions:
    # list(...) hands each TileDescription its own fresh list objects.
    tile_descriptions = [
      TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
      for shape, stages, warps in tile_parameters
    ]

    # Element types in the order A, B, then the accumulator type twice.
    accumulator = math_inst.element_accumulator
    data_type = [
      math_inst.element_a,
      math_inst.element_b,
      accumulator,
      accumulator,
    ]

    CreateGemmOperator(manifest, layouts, tile_descriptions,
      data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
|
||||
##################################################################################################
|
||||
#
|
||||
def GenerateSM80_Simt_complex(manifest, args):
|
||||
@ -2154,7 +2374,29 @@ def GenerateSM80_Simt_complex(manifest, args):
|
||||
|
||||
alignment_constraints = [1,]
|
||||
|
||||
data_type = [
|
||||
DataType.cf32,
|
||||
DataType.cf32,
|
||||
DataType.cf32,
|
||||
DataType.cf32
|
||||
]
|
||||
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
complex_transforms = [
|
||||
(ComplexTransform.none, ComplexTransform.none),
|
||||
(ComplexTransform.conj, ComplexTransform.none),
|
||||
(ComplexTransform.none, ComplexTransform.conj),
|
||||
(ComplexTransform.conj, ComplexTransform.conj)
|
||||
]
|
||||
|
||||
for math_inst in math_instructions:
|
||||
|
||||
tile_descriptions = [
|
||||
TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -2165,20 +2407,11 @@ def GenerateSM80_Simt_complex(manifest, args):
|
||||
TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
]
|
||||
data_type = [
|
||||
DataType.cf32,
|
||||
DataType.cf32,
|
||||
DataType.cf32,
|
||||
DataType.cf32
|
||||
]
|
||||
|
||||
complex_transforms = [
|
||||
(ComplexTransform.none, ComplexTransform.none),
|
||||
(ComplexTransform.conj, ComplexTransform.none),
|
||||
(ComplexTransform.none, ComplexTransform.conj),
|
||||
(ComplexTransform.conj, ComplexTransform.conj)
|
||||
]
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
#
|
||||
|
||||
###################################################################################################
|
||||
@ -2202,7 +2435,8 @@ def GenerateSM80(manifest, args):
|
||||
GenerateSM80_SparseTensorOp_168128_TN(manifest, args)
|
||||
GenerateSM80_TensorOp_16864_Interleaved(manifest, args)
|
||||
GenerateSM80_TensorOp_168256(manifest, args)
|
||||
GenerateSM80_Simt(manifest, args)
|
||||
GenerateSM80_Simt_f32(manifest, args)
|
||||
GenerateSM80_Simt_f64(manifest, args)
|
||||
GenerateSM80_Simt_complex(manifest, args)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
Reference in New Issue
Block a user