CUTLASS 2.6 (#298)

CUTLASS 2.6
This commit is contained in:
Manish Gupta
2021-07-22 21:40:53 -07:00
committed by GitHub
parent 6c29fe20ba
commit e5d51840e8
308 changed files with 32408 additions and 4722 deletions

View File

@ -141,18 +141,19 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t
###########################################################################################################
# ConvolutionOperator support variations
# ____________________________________________________________________
# ConvolutionalOperator | Analytic | Optimized
# ConvolutionalOperator | Analytic | Optimized
# ____________________________________________________________________
# | Fprop | (strided) | (strided)
# | Dgrad | (strided, unity*) | (unity)
# | Wgrad | (strided) | (strided)
# | Fprop | (strided) | (strided)
# | Dgrad | (strided, unity*) | (strided, unity)
# | Wgrad | (strided) | (strided)
# ____________________________________________________________________
#
# Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low
###########################################################################################################
# Convolution for 2D operations
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination):
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
element_a, element_b, element_c, element_epilogue = data_type
@ -169,33 +170,66 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
operations = []
for tile in tile_descriptions:
for conv_kind in conv_kinds:
A = TensorDescription(element_a, layout[0], alignment)
B = TensorDescription(element_b, layout[1], alignment)
C = TensorDescription(element_c, layout[2], alignment_c)
swizzling_functor_ = swizzling_functor
#
# Conv2d Fprop
#
if ConvKind.Fprop in conv_kinds:
# Strided support for Analytic and Optimized Fprop
for iterator_algorithm in iterator_algorithms:
A = TensorDescription(element_a, layout[0], alignment)
B = TensorDescription(element_b, layout[1], alignment)
C = TensorDescription(element_c, layout[2], alignment_c)
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
# unity stride only for Optimized Dgrad
if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad):
new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor)
manifest.append(new_operation)
operations.append(new_operation)
manifest.append(new_operation)
operations.append(new_operation)
#
# Conv2d Dgrad
#
if ConvKind.Dgrad in conv_kinds:
# strided dgrad is not supported by Optimized Dgrad
if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad):
continue
# Unity stride for Analytic and Optimized Dgrad
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_)
# strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor)
manifest.append(new_operation)
operations.append(new_operation)
# Strided support for Analytic Dgrad
# strided dgrad uses a special threadblock swizzle
# note that SwizzlingFunctor.StridedDgradHorizontal might be
# better for problem sizes with large activation channel count
swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1
new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)
manifest.append(new_operation)
operations.append(new_operation)
#
# Conv2d Wgrad
#
if ConvKind.Wgrad in conv_kinds:
# Strided support for Analytic and Optimized Wgrad
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
manifest.append(new_operation)
operations.append(new_operation)
return operations
# Convolution for 3D operations
def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination):
@ -315,6 +349,11 @@ def GenerateSM50_Simt_complex(manifest, args):
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
]
@ -1272,6 +1311,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
@ -1698,9 +1738,10 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32]
@ -1713,14 +1754,14 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
op.C.alignment = 8
@ -1934,6 +1975,7 @@ def GenerateSM80_TensorOp_1688(manifest, args):
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
@ -1993,7 +2035,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
[16, 8, 8], \
DataType.bf16, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_fast_bf16)
MathOperation.multiply_add_fast_bf16),
]
min_cc = 80
@ -2017,6 +2059,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
@ -2031,6 +2074,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
#
#
#
def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args):
@ -2155,9 +2199,9 @@ def GenerateSM80_TensorOp_884(manifest, args):
alignment_constraints = [1,]
tile_descriptions = [
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited),
TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
@ -2463,6 +2507,7 @@ if __name__ == "__main__":
parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file')
parser.add_argument('--selected-kernel-list', type=str, default=None, required=False,
help='Specify the output log file containing all enabled kernels in this build')
parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")
args = parser.parse_args()