refine the implementation
This commit is contained in:
@ -103,9 +103,9 @@ class Conv2dOperation:
|
||||
)
|
||||
|
||||
if self.stride_support == StrideSupport.Unity:
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride"
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}_unity_stride"
|
||||
else:
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"
|
||||
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"
|
||||
|
||||
return SubstituteTemplate(
|
||||
configuration_name,
|
||||
@ -114,6 +114,7 @@ class Conv2dOperation:
|
||||
'extended_name': self.extended_name(),
|
||||
'threadblock': threadblock,
|
||||
'layout': self.layout_name(),
|
||||
'alignment': "%d" % self.A.alignment,
|
||||
}
|
||||
)
|
||||
|
||||
@ -156,7 +157,9 @@ class EmitConv2dInstance:
|
||||
${stages},
|
||||
${math_operator},
|
||||
${iterator_algorithm},
|
||||
${stride_support}
|
||||
${stride_support},
|
||||
${align_a},
|
||||
${align_b}
|
||||
>::Kernel;
|
||||
"""
|
||||
|
||||
@ -198,7 +201,9 @@ class EmitConv2dInstance:
|
||||
'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
|
||||
'stride_support': StrideSupportTag[operation.stride_support],
|
||||
'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \
|
||||
MathOperationTag[operation.tile_description.math_instruction.math_operation]
|
||||
MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
||||
'align_a': str(operation.A.alignment),
|
||||
'align_b': str(operation.B.alignment),
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.template, values)
|
||||
@ -341,4 +346,3 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
|
||||
|
||||
@ -151,14 +151,13 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t
|
||||
# Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low
|
||||
###########################################################################################################
|
||||
# Convolution for 2D operations
|
||||
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
|
||||
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \
|
||||
conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
|
||||
|
||||
element_a, element_b, element_c, element_epilogue = data_type
|
||||
|
||||
# one exceptional case
|
||||
alignment_c = min(8, alignment)
|
||||
|
||||
# iterator algorithm (analytic and optimized)
|
||||
iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
|
||||
@ -166,66 +165,71 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
|
||||
# by default, only generate the largest tile size
|
||||
if manifest.args.kernels == '':
|
||||
tile_descriptions = [tile_descriptions[0],]
|
||||
alignment_constraints = [alignment_constraints[0],]
|
||||
|
||||
operations = []
|
||||
|
||||
for tile in tile_descriptions:
|
||||
A = TensorDescription(element_a, layout[0], alignment)
|
||||
B = TensorDescription(element_b, layout[1], alignment)
|
||||
C = TensorDescription(element_c, layout[2], alignment_c)
|
||||
|
||||
swizzling_functor_ = swizzling_functor
|
||||
for alignment in alignment_constraints:
|
||||
|
||||
#
|
||||
# Conv2d Fprop
|
||||
#
|
||||
if ConvKind.Fprop in conv_kinds:
|
||||
alignment_c = min(8, alignment)
|
||||
|
||||
# Strided support for Analytic and Optimized Fprop
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
#
|
||||
# Conv2d Dgrad
|
||||
#
|
||||
if ConvKind.Dgrad in conv_kinds:
|
||||
|
||||
# Unity stride for Analytic and Optimized Dgrad
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
# Strided support for Analytic Dgrad
|
||||
# strided dgrad uses a special threadblock swizzle
|
||||
# note that SwizzlingFunctor.StridedDgradHorizontal might be
|
||||
# better for problem sizes with large activation channel count
|
||||
swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1
|
||||
|
||||
new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
A = TensorDescription(element_a, layout[0], alignment)
|
||||
B = TensorDescription(element_b, layout[1], alignment)
|
||||
C = TensorDescription(element_c, layout[2], alignment_c)
|
||||
|
||||
#
|
||||
# Conv2d Wgrad
|
||||
#
|
||||
if ConvKind.Wgrad in conv_kinds:
|
||||
|
||||
# Strided support for Analytic and Optimized Wgrad
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
swizzling_functor_ = swizzling_functor
|
||||
|
||||
#
|
||||
# Conv2d Fprop
|
||||
#
|
||||
if ConvKind.Fprop in conv_kinds:
|
||||
|
||||
# Strided support for Analytic and Optimized Fprop
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
#
|
||||
# Conv2d Dgrad
|
||||
#
|
||||
if ConvKind.Dgrad in conv_kinds:
|
||||
|
||||
# Unity stride for Analytic and Optimized Dgrad
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
# Strided support for Analytic Dgrad
|
||||
# strided dgrad uses a special threadblock swizzle
|
||||
# note that SwizzlingFunctor.StridedDgradHorizontal might be
|
||||
# better for problem sizes with large activation channel count
|
||||
swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1
|
||||
|
||||
new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
#
|
||||
# Conv2d Wgrad
|
||||
#
|
||||
if ConvKind.Wgrad in conv_kinds:
|
||||
|
||||
# Strided support for Analytic and Optimized Wgrad
|
||||
for iterator_algorithm in iterator_algorithms:
|
||||
new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
return operations
|
||||
|
||||
@ -322,7 +326,7 @@ def GenerateSM50_Simt(manifest, args):
|
||||
|
||||
if math_inst.element_a == DataType.f32:
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
@ -369,7 +373,7 @@ def GenerateSM50_Simt_complex(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
@ -543,7 +547,7 @@ def GenerateSM70_TensorOp_884(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
@ -558,7 +562,7 @@ def GenerateSM70_TensorOp_884(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
|
||||
|
||||
#
|
||||
def GenerateSM70_PlanarComplexTensorOp_884(manifest, args):
|
||||
@ -754,7 +758,7 @@ def GenerateSM75_TensorOp_1688(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
@ -769,7 +773,7 @@ def GenerateSM75_TensorOp_1688(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
|
||||
|
||||
#
|
||||
|
||||
@ -891,7 +895,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
@ -909,7 +913,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args):
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@ -972,7 +976,7 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args):
|
||||
conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 8
|
||||
@ -1028,7 +1032,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
if math_inst.element_a != math_inst.element_accumulator:
|
||||
@ -1046,7 +1050,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@ -1112,7 +1116,7 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):
|
||||
conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 16
|
||||
@ -1250,7 +1254,7 @@ def GenerateSM75_Simt_complex(manifest, args):
|
||||
]
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
def GenerateSM75(manifest, args):
|
||||
@ -1338,7 +1342,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
@ -1354,7 +1358,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 8)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
|
||||
CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8)
|
||||
#
|
||||
|
||||
@ -1572,10 +1576,10 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args):
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@ -1689,7 +1693,7 @@ def GenerateSM80_TensorOp_16832_Interleaved(manifest, args):
|
||||
conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 16, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 8
|
||||
@ -1758,10 +1762,10 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args):
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@ -1878,7 +1882,7 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):
|
||||
conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
op.C.alignment = 16
|
||||
@ -2005,9 +2009,9 @@ def GenerateSM80_TensorOp_1688(manifest, args):
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 4)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
@ -2076,7 +2080,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
#
|
||||
@ -2366,7 +2370,7 @@ def GenerateSM80_Simt_f32(manifest, args):
|
||||
data_type, alignment_constraints)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
|
||||
@ -2467,7 +2471,7 @@ def GenerateSM80_Simt_complex(manifest, args):
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms)
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 1)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints)
|
||||
#
|
||||
|
||||
###################################################################################################
|
||||
|
||||
Reference in New Issue
Block a user