CUTLASS 2.9 (#468)

This commit is contained in:
Andrew Kerr
2022-04-23 15:02:38 -04:00
committed by GitHub
parent dd571f0edb
commit 12f4108ac2
1100 changed files with 94818 additions and 20385 deletions

View File

@ -187,6 +187,17 @@ DataTypeSize = {
}
###################################################################################################
#
class BlasMode(enum.Enum):
symmetric = enum_auto()
hermitian = enum_auto()
#
BlasModeTag = {
BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
}
#
class ComplexTransform(enum.Enum):
none = enum_auto()
@ -341,6 +352,64 @@ ShortComplexLayoutNames = {
}
###################################################################################################
#
class SideMode(enum.Enum):
Left = enum_auto()
Right = enum_auto()
#
SideModeTag = {
SideMode.Left: 'cutlass::SideMode::kLeft',
SideMode.Right: 'cutlass::SideMode::kRight'
}
#
ShortSideModeNames = {
SideMode.Left: 'ls',
SideMode.Right: 'rs'
}
###################################################################################################
#
class FillMode(enum.Enum):
Lower = enum_auto()
Upper = enum_auto()
#
FillModeTag = {
FillMode.Lower: 'cutlass::FillMode::kLower',
FillMode.Upper: 'cutlass::FillMode::kUpper'
}
#
ShortFillModeNames = {
FillMode.Lower: 'l',
FillMode.Upper: 'u'
}
###################################################################################################
#
class DiagType(enum.Enum):
NonUnit = enum_auto()
Unit = enum_auto()
#
DiagTypeTag = {
DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
DiagType.Unit: 'cutlass::DiagType::kUnit'
}
#
ShortDiagTypeNames = {
DiagType.NonUnit: 'nu',
DiagType.Unit: 'un'
}
###################################################################################################
#
class OpcodeClass(enum.Enum):
Simt = enum_auto()
@ -366,12 +435,20 @@ OpcodeClassTag = {
#
class OperationKind(enum.Enum):
Gemm = enum_auto()
RankK = enum_auto()
Rank2K = enum_auto()
Trmm = enum_auto()
Symm = enum_auto()
Conv2d = enum_auto()
Conv3d = enum_auto()
#
OperationKindNames = {
OperationKind.Gemm: 'gemm'
, OperationKind.RankK: 'rank_k'
, OperationKind.Rank2K: 'rank_2k'
, OperationKind.Trmm: 'trmm'
, OperationKind.Symm: 'symm'
, OperationKind.Conv2d: 'conv2d'
, OperationKind.Conv3d: 'conv3d'
}
@ -414,6 +491,7 @@ class GemmKind(enum.Enum):
Universal = enum_auto()
PlanarComplex = enum_auto()
PlanarComplexArray = enum_auto()
Grouped = enum_auto()
#
GemmKindNames = {
@ -422,6 +500,34 @@ GemmKindNames = {
GemmKind.Universal: "gemm",
GemmKind.PlanarComplex: "gemm_planar_complex",
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
GemmKind.Grouped: "gemm_grouped"
}
#
class RankKKind(enum.Enum):
Universal = enum_auto()
#
RankKKindNames = {
RankKKind.Universal: "rank_k"
}
#
class TrmmKind(enum.Enum):
Universal = enum_auto()
#
TrmmKindNames = {
TrmmKind.Universal: "trmm"
}
#
class SymmKind(enum.Enum):
Universal = enum_auto()
#
SymmKindNames = {
SymmKind.Universal: "symm"
}
#
@ -483,16 +589,22 @@ ConvKindNames = {
class IteratorAlgorithm(enum.Enum):
Analytic = enum_auto()
Optimized = enum_auto()
FixedChannels = enum_auto()
FewChannels = enum_auto()
#
IteratorAlgorithmTag = {
IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels'
}
IteratorAlgorithmNames = {
IteratorAlgorithm.Analytic: 'analytic',
IteratorAlgorithm.Optimized: 'optimized',
IteratorAlgorithm.FixedChannels: 'fixed_channels',
IteratorAlgorithm.FewChannels: 'few_channels'
}
#
@ -547,4 +659,25 @@ class TensorDescription:
self.alignment = alignment
self.complex_transform = complex_transform
#
class SymmetricTensorDescription:
def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
self.element = element
self.layout = layout
self.fill_mode = fill_mode
self.alignment = alignment
self.complex_transform = complex_transform
self.side_mode = side_mode
#
class TriangularTensorDescription:
def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none):
self.element = element
self.layout = layout
self.side_mode = side_mode
self.fill_mode = fill_mode
self.diag_type = diag_type
self.alignment = alignment
self.complex_transform = complex_transform
###################################################################################################