update 3.8 v2 (#2112)

* update 3.8 v2

* update 3.8

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-02-19 19:03:14 -08:00
committed by GitHub
parent e9627ce55b
commit b84e9802d8
166 changed files with 3986 additions and 4037 deletions

View File

@ -69,7 +69,7 @@ using ${operation_name}_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
${arch},
${opcode_class_epi},
${output_cta_tile_shape}, // output cta tile shape
${mma_tile_shape}, // mma tile shape
${cluster_shape}, // cluster shape
${epi_tile_mn},
${element_accumulator},
@ -109,26 +109,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
def arch_number_to_type(self, arch: int) -> str:
return f"cutlass::arch::Sm{arch}"
def output_cta_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
# For all three kinds of convolutions, the tile shape's K mode
# differs from GEMM in that needs to be wrapped in a Shape.
# For Wgrad convolutions specifically,
# the N tile shape also needs to be wrapped in a Shape.
m_template = 'cute::_${cta_m}'
if operation.conv_kind == ConvKind.Wgrad:
n_template = 'cute::Shape<cute::_${cta_n}>'
else:
n_template = 'cute::_${cta_n}'
k_template = 'cute::Shape<cute::_${cta_k}>'
output_cta_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
values = {
'cta_m': cta_m,
'cta_n': cta_n,
'cta_k': cta_k
}
return Template(output_cta_tile_shape_template).substitute(values)
def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
mma_m = cta_m
mma_n = cta_n
@ -223,7 +203,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class': opcode_class,
'arch': self.arch_number_to_type(operation.arch),
'output_cta_tile_shape': self.output_cta_tile_shape(operation, cta_m, cta_n, cta_k),
'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
'cluster_shape': self.cluster_shape(operation),
'opcode_class_epi': opcode_class_epi,