CUTLASS 2.9 (#468)
This commit is contained in:
@ -187,6 +187,17 @@ DataTypeSize = {
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
class BlasMode(enum.Enum):
|
||||
symmetric = enum_auto()
|
||||
hermitian = enum_auto()
|
||||
|
||||
#
|
||||
BlasModeTag = {
|
||||
BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
|
||||
BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
|
||||
}
|
||||
|
||||
#
|
||||
class ComplexTransform(enum.Enum):
|
||||
none = enum_auto()
|
||||
@ -341,6 +352,64 @@ ShortComplexLayoutNames = {
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class SideMode(enum.Enum):
|
||||
Left = enum_auto()
|
||||
Right = enum_auto()
|
||||
|
||||
#
|
||||
SideModeTag = {
|
||||
SideMode.Left: 'cutlass::SideMode::kLeft',
|
||||
SideMode.Right: 'cutlass::SideMode::kRight'
|
||||
}
|
||||
|
||||
#
|
||||
ShortSideModeNames = {
|
||||
SideMode.Left: 'ls',
|
||||
SideMode.Right: 'rs'
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class FillMode(enum.Enum):
|
||||
Lower = enum_auto()
|
||||
Upper = enum_auto()
|
||||
|
||||
#
|
||||
FillModeTag = {
|
||||
FillMode.Lower: 'cutlass::FillMode::kLower',
|
||||
FillMode.Upper: 'cutlass::FillMode::kUpper'
|
||||
}
|
||||
|
||||
#
|
||||
ShortFillModeNames = {
|
||||
FillMode.Lower: 'l',
|
||||
FillMode.Upper: 'u'
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class DiagType(enum.Enum):
|
||||
NonUnit = enum_auto()
|
||||
Unit = enum_auto()
|
||||
|
||||
#
|
||||
DiagTypeTag = {
|
||||
DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
|
||||
DiagType.Unit: 'cutlass::DiagType::kUnit'
|
||||
}
|
||||
|
||||
#
|
||||
ShortDiagTypeNames = {
|
||||
DiagType.NonUnit: 'nu',
|
||||
DiagType.Unit: 'un'
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class OpcodeClass(enum.Enum):
|
||||
Simt = enum_auto()
|
||||
@ -366,12 +435,20 @@ OpcodeClassTag = {
|
||||
#
|
||||
class OperationKind(enum.Enum):
|
||||
Gemm = enum_auto()
|
||||
RankK = enum_auto()
|
||||
Rank2K = enum_auto()
|
||||
Trmm = enum_auto()
|
||||
Symm = enum_auto()
|
||||
Conv2d = enum_auto()
|
||||
Conv3d = enum_auto()
|
||||
|
||||
#
|
||||
OperationKindNames = {
|
||||
OperationKind.Gemm: 'gemm'
|
||||
, OperationKind.RankK: 'rank_k'
|
||||
, OperationKind.Rank2K: 'rank_2k'
|
||||
, OperationKind.Trmm: 'trmm'
|
||||
, OperationKind.Symm: 'symm'
|
||||
, OperationKind.Conv2d: 'conv2d'
|
||||
, OperationKind.Conv3d: 'conv3d'
|
||||
}
|
||||
@ -414,6 +491,7 @@ class GemmKind(enum.Enum):
|
||||
Universal = enum_auto()
|
||||
PlanarComplex = enum_auto()
|
||||
PlanarComplexArray = enum_auto()
|
||||
Grouped = enum_auto()
|
||||
|
||||
#
|
||||
GemmKindNames = {
|
||||
@ -422,6 +500,34 @@ GemmKindNames = {
|
||||
GemmKind.Universal: "gemm",
|
||||
GemmKind.PlanarComplex: "gemm_planar_complex",
|
||||
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
|
||||
GemmKind.Grouped: "gemm_grouped"
|
||||
}
|
||||
|
||||
#
|
||||
class RankKKind(enum.Enum):
|
||||
Universal = enum_auto()
|
||||
|
||||
#
|
||||
RankKKindNames = {
|
||||
RankKKind.Universal: "rank_k"
|
||||
}
|
||||
|
||||
#
|
||||
class TrmmKind(enum.Enum):
|
||||
Universal = enum_auto()
|
||||
|
||||
#
|
||||
TrmmKindNames = {
|
||||
TrmmKind.Universal: "trmm"
|
||||
}
|
||||
|
||||
#
|
||||
class SymmKind(enum.Enum):
|
||||
Universal = enum_auto()
|
||||
|
||||
#
|
||||
SymmKindNames = {
|
||||
SymmKind.Universal: "symm"
|
||||
}
|
||||
|
||||
#
|
||||
@ -483,16 +589,22 @@ ConvKindNames = {
|
||||
class IteratorAlgorithm(enum.Enum):
|
||||
Analytic = enum_auto()
|
||||
Optimized = enum_auto()
|
||||
FixedChannels = enum_auto()
|
||||
FewChannels = enum_auto()
|
||||
|
||||
#
|
||||
IteratorAlgorithmTag = {
|
||||
IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
|
||||
IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
|
||||
IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
|
||||
IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels'
|
||||
}
|
||||
|
||||
IteratorAlgorithmNames = {
|
||||
IteratorAlgorithm.Analytic: 'analytic',
|
||||
IteratorAlgorithm.Optimized: 'optimized',
|
||||
IteratorAlgorithm.FixedChannels: 'fixed_channels',
|
||||
IteratorAlgorithm.FewChannels: 'few_channels'
|
||||
}
|
||||
|
||||
#
|
||||
@ -547,4 +659,25 @@ class TensorDescription:
|
||||
self.alignment = alignment
|
||||
self.complex_transform = complex_transform
|
||||
|
||||
#
|
||||
class SymmetricTensorDescription:
|
||||
def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
|
||||
self.element = element
|
||||
self.layout = layout
|
||||
self.fill_mode = fill_mode
|
||||
self.alignment = alignment
|
||||
self.complex_transform = complex_transform
|
||||
self.side_mode = side_mode
|
||||
|
||||
#
|
||||
class TriangularTensorDescription:
|
||||
def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none):
|
||||
self.element = element
|
||||
self.layout = layout
|
||||
self.side_mode = side_mode
|
||||
self.fill_mode = fill_mode
|
||||
self.diag_type = diag_type
|
||||
self.alignment = alignment
|
||||
self.complex_transform = complex_transform
|
||||
|
||||
###################################################################################################
|
||||
|
||||
Reference in New Issue
Block a user