CUTLASS 2.10 (#615)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2022-09-03 15:48:46 -07:00
committed by GitHub
parent ca23ff7924
commit b72cbf957d
289 changed files with 43708 additions and 2513 deletions

View File

@ -149,6 +149,35 @@ class GemmOperation:
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
return self.procedural_name()
###################################################################################################
#
# Data structure modeling a grouped GEMM operation
#
###################################################################################################
#
class GroupedGemmOperation(GemmOperation):
  ''' A GemmOperation that additionally records the scheduler mode used for grouped problems. '''

  #
  def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue,
      epilogue_functor = EpilogueFunctor.LinearCombination,
      swizzling_functor = SwizzlingFunctor.Identity8,
      scheduler_mode = GroupScheduleMode.Device):
    ''' Forwards the common GEMM arguments to GemmOperation and stores scheduler_mode on the instance. '''
    super().__init__(
      gemm_kind, arch, tile_description, A, B, C, element_epilogue,
      epilogue_functor, swizzling_functor)

    # The only state this subclass adds on top of GemmOperation.
    self.scheduler_mode = scheduler_mode

  #
  def procedural_name(self):
    ''' Returns the base class procedural name extended with a "_schedule${schedule}" component,
    where ${schedule} is filled from ShortGroupScheduleModeNames via SubstituteTemplate. '''
    # Build the template from the parent's name first, then resolve the short mode name.
    name_template = super().procedural_name() + "_schedule${schedule}"
    template_values = {
      'schedule': ShortGroupScheduleModeNames[self.scheduler_mode],
    }
    return SubstituteTemplate(name_template, template_values)
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
@ -738,6 +767,7 @@ using ${operation_name}_base =
${epilogue_functor},
${swizzling_functor},
${stages},
${scheduler_mode},
${math_operation}
>::GemmKernel;
@ -817,6 +847,7 @@ ${compile_guard_end}
'align_b': str(operation.B.alignment),
'transform_a': ComplexTransformTag[operation.A.complex_transform],
'transform_b': ComplexTransformTag[operation.B.complex_transform],
'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode],
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
}