Add GMMA shape m64n40k16 (#1864)
This commit is contained in:
@ -450,6 +450,9 @@ using CLayout_64x16 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _2>>,
|
||||
using CLayout_64x32 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _4>>,
|
||||
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
|
||||
|
||||
using CLayout_64x40 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _5>>,
|
||||
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
|
||||
|
||||
using CLayout_64x48 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _6>>,
|
||||
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
|
||||
|
||||
@ -1773,6 +1776,39 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
||||
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
||||
>
|
||||
using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = half_t;
|
||||
using ValTypeB = half_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<40>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 40, 16>;
|
||||
using CLayout = GMMA::CLayout_64x40;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
@ -2846,6 +2882,39 @@ struct MMA_Traits<SM90_64x32x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
||||
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
||||
>
|
||||
using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>;
|
||||
|
||||
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
|
||||
struct MMA_Traits<SM90_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = bfloat16_t;
|
||||
using ValTypeB = bfloat16_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using FrgTypeA = GMMA::smem_desc<tnspA>;
|
||||
using FrgTypeB = GMMA::smem_desc<tnspB>;
|
||||
|
||||
using Shape_MNK = Shape<_64,Int<40>,_16>;
|
||||
using ThrID = Layout<_128>;
|
||||
using ALayout = GMMA::ABLayout< 64, 16>;
|
||||
using BLayout = GMMA::ABLayout< 40, 16>;
|
||||
using CLayout = GMMA::CLayout_64x40;
|
||||
|
||||
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
|
||||
template <
|
||||
GMMA::Major tnspA,
|
||||
GMMA::Major tnspB,
|
||||
|
||||
Reference in New Issue
Block a user