refine the implementation

This commit is contained in:
Haicheng Wu
2021-09-08 13:14:08 +00:00
parent 4e8af93da1
commit 59e2aa505a
45 changed files with 1593 additions and 706 deletions

View File

@ -103,9 +103,9 @@ class Conv2dOperation:
)
if self.stride_support == StrideSupport.Unity:
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride"
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}_unity_stride"
else:
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
return SubstituteTemplate(
configuration_name,
@ -114,6 +114,7 @@ class Conv2dOperation:
'extended_name': self.extended_name(),
'threadblock': threadblock,
'layout': self.layout_name(),
'alignment': "%d" % self.A.alignment,
}
)
@ -156,7 +157,9 @@ class EmitConv2dInstance:
${stages},
${math_operator},
${iterator_algorithm},
${stride_support}
${stride_support},
${align_a},
${align_b}
>::Kernel;
"""
@ -198,7 +201,9 @@ class EmitConv2dInstance:
'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
'stride_support': StrideSupportTag[operation.stride_support],
'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \
MathOperationTag[operation.tile_description.math_instruction.math_operation]
MathOperationTag[operation.tile_description.math_instruction.math_operation],
'align_a': str(operation.A.alignment),
'align_b': str(operation.B.alignment),
}
return SubstituteTemplate(self.template, values)
@ -341,4 +346,3 @@ void initialize_${configuration_name}(Manifest &manifest) {
###################################################################################################
###################################################################################################

View File

@ -151,14 +151,13 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t
# Note: Operators marked (*) are supported but not generated, to keep the instantiated kernel count low
###########################################################################################################
# Convolution for 2D operations
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
element_a, element_b, element_c, element_epilogue = data_type
# one exceptional case: the alignment of operand C is capped at 8 (see alignment_c)
alignment_c = min(8, alignment)
# iterator algorithm (analytic and optimized)
iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
@ -166,66 +165,71 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
# by default, only generate the largest tile size
if manifest.args.kernels == '':
tile_descriptions = [tile_descriptions[0],]
alignment_constraints = [alignment_constraints[0],]
operations = []
for tile in tile_descriptions:
A = TensorDescription(element_a, layout[0], alignment)
B = TensorDescription(element_b, layout[1], alignment)
C = TensorDescription(element_c, layout[2], alignment_c)
swizzling_functor_ = swizzling_functor
for alignment in alignment_constraints:
#
# Conv2d Fprop
#
if ConvKind.Fprop in conv_kinds:
alignment_c = min(8, alignment)
# Strided support for Analytic and Optimized Fprop
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
manifest.append(new_operation)
operations.append(new_operation)
#
# Conv2d Dgrad
#
if ConvKind.Dgrad in conv_kinds:
# Unity stride for Analytic and Optimized Dgrad
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_)
manifest.append(new_operation)
operations.append(new_operation)
# Strided support for Analytic Dgrad
# strided dgrad uses a special threadblock swizzle
# note that SwizzlingFunctor.StridedDgradHorizontal might be
# better for problem sizes with large activation channel count
swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1
new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)
manifest.append(new_operation)
operations.append(new_operation)
A = TensorDescription(element_a, layout[0], alignment)
B = TensorDescription(element_b, layout[1], alignment)
C = TensorDescription(element_c, layout[2], alignment_c)
#
# Conv2d Wgrad
#
if ConvKind.Wgrad in conv_kinds:
# Strided support for Analytic and Optimized Wgrad
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
swizzling_functor_ = swizzling_functor
#
# Conv2d Fprop
#
if ConvKind.Fprop in conv_kinds:
# Strided support for Analytic and Optimized Fprop
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
manifest.append(new_operation)
operations.append(new_operation)
#
# Conv2d Dgrad
#
if ConvKind.Dgrad in conv_kinds:
# Unity stride for Analytic and Optimized Dgrad
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_)
manifest.append(new_operation)
operations.append(new_operation)
# Strided support for Analytic Dgrad
# strided dgrad uses a special threadblock swizzle
# note that SwizzlingFunctor.StridedDgradHorizontal might be
# better for problem sizes with large activation channel count
swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1
new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)
manifest.append(new_operation)
operations.append(new_operation)
#
# Conv2d Wgrad
#
if ConvKind.Wgrad in conv_kinds:
# Strided support for Analytic and Optimized Wgrad
for iterator_algorithm in iterator_algorithms:
new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
manifest.append(new_operation)
operations.append(new_operation)
return operations
@ -322,7 +326,7 @@ def GenerateSM50_Simt(manifest, args):
if math_inst.element_a == DataType.f32:
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
#
#
@ -369,7 +373,7 @@ def GenerateSM50_Simt_complex(manifest, args):
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
#
#
@ -543,7 +547,7 @@ def GenerateSM70_TensorOp_884(manifest, args):
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -558,7 +562,7 @@ def GenerateSM70_TensorOp_884(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
#
def GenerateSM70_PlanarComplexTensorOp_884(manifest, args):
@ -754,7 +758,7 @@ def GenerateSM75_TensorOp_1688(manifest, args):
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -769,7 +773,7 @@ def GenerateSM75_TensorOp_1688(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
#
@ -891,7 +895,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -909,7 +913,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
@ -972,7 +976,7 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args):
conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
op.C.alignment = 8
@ -1028,7 +1032,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
@ -1046,7 +1050,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
@ -1112,7 +1116,7 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):
conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
op.C.alignment = 16
@ -1250,7 +1254,7 @@ def GenerateSM75_Simt_complex(manifest, args):
]
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
#
def GenerateSM75(manifest, args):
@ -1338,7 +1342,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
@ -1354,7 +1358,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8)
#
@ -1572,10 +1576,10 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
@ -1689,7 +1693,7 @@ def GenerateSM80_TensorOp_16832_Interleaved(manifest, args):
conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
op.C.alignment = 8
@ -1758,10 +1762,10 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
@ -1878,7 +1882,7 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):
conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
op.C.alignment = 16
@ -2005,9 +2009,9 @@ def GenerateSM80_TensorOp_1688(manifest, args):
data_type_mixed, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 4)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
#
#
@ -2076,7 +2080,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
#
#
@ -2366,7 +2370,7 @@ def GenerateSM80_Simt_f32(manifest, args):
data_type, alignment_constraints)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
#
@ -2467,7 +2471,7 @@ def GenerateSM80_Simt_complex(manifest, args):
CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
#
###################################################################################################