update 3.8 v2 (#2112)
* update 3.8 v2
* update 3.8

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@ -69,7 +69,7 @@ using ${operation_name}_epilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
${arch},
|
||||
${opcode_class_epi},
|
||||
${output_cta_tile_shape}, // output cta tile shape
|
||||
${mma_tile_shape}, // mma tile shape
|
||||
${cluster_shape}, // cluster shape
|
||||
${epi_tile_mn},
|
||||
${element_accumulator},
|
||||
@ -109,26 +109,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
||||
def arch_number_to_type(self, arch: int) -> str:
|
||||
return f"cutlass::arch::Sm{arch}"
|
||||
|
||||
def output_cta_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
|
||||
# For all three kinds of convolutions, the tile shape's K mode
|
||||
# differs from GEMM in that it needs to be wrapped in a Shape.
|
||||
# For Wgrad convolutions specifically,
|
||||
# the N tile shape also needs to be wrapped in a Shape.
|
||||
m_template = 'cute::_${cta_m}'
|
||||
if operation.conv_kind == ConvKind.Wgrad:
|
||||
n_template = 'cute::Shape<cute::_${cta_n}>'
|
||||
else:
|
||||
n_template = 'cute::_${cta_n}'
|
||||
k_template = 'cute::Shape<cute::_${cta_k}>'
|
||||
|
||||
output_cta_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
||||
values = {
|
||||
'cta_m': cta_m,
|
||||
'cta_n': cta_n,
|
||||
'cta_k': cta_k
|
||||
}
|
||||
return Template(output_cta_tile_shape_template).substitute(values)
|
||||
|
||||
def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
|
||||
mma_m = cta_m
|
||||
mma_n = cta_n
|
||||
@ -223,7 +203,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class': opcode_class,
|
||||
'arch': self.arch_number_to_type(operation.arch),
|
||||
'output_cta_tile_shape': self.output_cta_tile_shape(operation, cta_m, cta_n, cta_k),
|
||||
'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
|
||||
'cluster_shape': self.cluster_shape(operation),
|
||||
'opcode_class_epi': opcode_class_epi,
|
||||
|
||||
Reference in New Issue
Block a user