Updates for CUTLASS 3.5.0 (#1468)

This commit is contained in:
Vijay Thakkar
2024-04-11 21:33:40 -04:00
committed by GitHub
parent a40e08e9d5
commit 7d49e6c7e2
171 changed files with 7526 additions and 1888 deletions

View File

@ -205,7 +205,7 @@ class GemmOperation:
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_a = DataTypeNames[self.A.element],
element_b = DataTypeNames[self.B.element],
element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator],
element_acc = DataTypeNames[self.accumulator_type()],
element_c = DataTypeNames[self.C.element],
element_d = DataTypeNames[self.D.element],
core_name = self.core_name())
@ -216,7 +216,7 @@ class GemmOperation:
datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_a = DataTypeNames[self.A.element],
element_b = DataTypeNames[self.B.element],
element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator],
element_acc = DataTypeNames[self.accumulator_type()],
element_c = DataTypeNames[self.C.element],
element_d = DataTypeNames[self.D.element])
return datatype_name
@ -744,7 +744,7 @@ using ${operation_name}_mainloop =
cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
${stages},
${kernel_schedule}
${kernel_schedule}
>::CollectiveOp;
// Gemm operator ${operation_name}
@ -817,8 +817,9 @@ ${compile_guard_end}
else:
epilogue_functor = self.epilogue_functor.emit_declaration()
#
element_a = DataTypeTag[operation.A.element]
element_b = DataTypeTag[operation.B.element]
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
values = {
'operation_name': operation.procedural_name(),

View File

@ -967,6 +967,7 @@ class ConvOperation3x:
def configuration_name(self):
prefix = 'cutlass3x'
arch = self.arch
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
tbm = self.tile_description.tile_shape[0]
tbn = self.tile_description.tile_shape[1]
@ -979,7 +980,7 @@ class ConvOperation3x:
kernel_schedule = KernelScheduleSuffixes[self.kernel_schedule]
epilogue_schedule = EpilogueScheduleSuffixes[self.epilogue_schedule]
return f"{prefix}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}"
return f"{prefix}_sm{arch}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}"
def procedural_name(self):
return self.configuration_name()

View File

@ -250,6 +250,12 @@ ComplexTransformTag = {
ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}
# Used for cutlass3x complex kernel collective mainloop builder instantiation
ComplexTransformTag3x = {
ComplexTransform.none: 'cute::identity',
ComplexTransform.conj: 'cute::conjugate',
}
#
RealComplexBijection = [
(DataType.f16, DataType.cf16),