Updates for CUTLASS 3.5.0 (#1468)
This commit is contained in:
@ -205,7 +205,7 @@ class GemmOperation:
|
||||
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_a = DataTypeNames[self.A.element],
|
||||
element_b = DataTypeNames[self.B.element],
|
||||
element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator],
|
||||
element_acc = DataTypeNames[self.accumulator_type()],
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = DataTypeNames[self.D.element],
|
||||
core_name = self.core_name())
|
||||
@ -216,7 +216,7 @@ class GemmOperation:
|
||||
datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_a = DataTypeNames[self.A.element],
|
||||
element_b = DataTypeNames[self.B.element],
|
||||
element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator],
|
||||
element_acc = DataTypeNames[self.accumulator_type()],
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = DataTypeNames[self.D.element])
|
||||
return datatype_name
|
||||
@ -744,7 +744,7 @@ using ${operation_name}_mainloop =
|
||||
cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
|
||||
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
||||
${stages},
|
||||
${kernel_schedule}
|
||||
${kernel_schedule}
|
||||
>::CollectiveOp;
|
||||
|
||||
// Gemm operator ${operation_name}
|
||||
@ -817,8 +817,9 @@ ${compile_guard_end}
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
#
|
||||
element_a = DataTypeTag[operation.A.element]
|
||||
element_b = DataTypeTag[operation.B.element]
|
||||
# For Cutlass3x complex kernels, ElementA / ElementB in the collective mainloop builder is a tuple, e.g. cute::tuple<Element, Transform>, where Transform is either cute::identity or cute::conjugate.
|
||||
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
|
||||
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
|
||||
@ -967,6 +967,7 @@ class ConvOperation3x:
|
||||
|
||||
def configuration_name(self):
|
||||
prefix = 'cutlass3x'
|
||||
arch = self.arch
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
tbm = self.tile_description.tile_shape[0]
|
||||
tbn = self.tile_description.tile_shape[1]
|
||||
@ -979,7 +980,7 @@ class ConvOperation3x:
|
||||
kernel_schedule = KernelScheduleSuffixes[self.kernel_schedule]
|
||||
epilogue_schedule = EpilogueScheduleSuffixes[self.epilogue_schedule]
|
||||
|
||||
return f"{prefix}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}"
|
||||
return f"{prefix}_sm{arch}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}"
|
||||
|
||||
# Returns the procedural name of this operation; it is exactly the string
# produced by configuration_name().
def procedural_name(self):
|
||||
return self.configuration_name()
|
||||
|
||||
@ -250,6 +250,12 @@ ComplexTransformTag = {
|
||||
ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
|
||||
}
|
||||
|
||||
# Used for cutlass3x complex kernel collective mainloop builder instantiation.
# Maps each ComplexTransform member to the name of a cute functor:
# ComplexTransform.none -> 'cute::identity', ComplexTransform.conj -> 'cute::conjugate'.
# The emitter places the looked-up string as the second element of the
# cute::tuple<Element, Transform> operand type it builds for complex operations.
|
||||
ComplexTransformTag3x = {
|
||||
ComplexTransform.none: 'cute::identity',
|
||||
ComplexTransform.conj: 'cute::conjugate',
|
||||
}
|
||||
|
||||
#
|
||||
RealComplexBijection = [
|
||||
(DataType.f16, DataType.cf16),
|
||||
|
||||
Reference in New Issue
Block a user