Add GMMA shape m64n40k16 (#1864)

This commit is contained in:
Tri Dao
2024-10-21 17:41:47 -07:00
committed by GitHub
parent 08101d9d0c
commit 5b50a8faaf
3 changed files with 189 additions and 0 deletions

View File

@ -450,6 +450,9 @@ using CLayout_64x16 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _2>>,
using CLayout_64x32 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _4>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x40 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _5>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
using CLayout_64x48 = Layout<Shape <Shape < _4,_8, _4>,Shape < _2,_2, _6>>,
Stride<Stride<_128,_1,_16>,Stride<_64,_8,_512>>>;
@ -1773,6 +1776,39 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = half_t;
using ValTypeB = half_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<40>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 40, 16>;
using CLayout = GMMA::CLayout_64x40;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
@ -2846,6 +2882,39 @@ struct MMA_Traits<SM90_64x32x16_F32BF16BF16_RS<tnspA, tnspB, scaleA, scaleB>>
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
template <
GMMA::Major tnspA,
GMMA::Major tnspB,
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>;
template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x40x16_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
using ValTypeB = bfloat16_t;
using ValTypeC = float;
using FrgTypeA = GMMA::smem_desc<tnspA>;
using FrgTypeB = GMMA::smem_desc<tnspB>;
using Shape_MNK = Shape<_64,Int<40>,_16>;
using ThrID = Layout<_128>;
using ALayout = GMMA::ABLayout< 64, 16>;
using BLayout = GMMA::ABLayout< 40, 16>;
using CLayout = GMMA::CLayout_64x40;
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
template <
GMMA::Major tnspA,
GMMA::Major tnspB,