[SM90] Change register allocation for TileN=208 to avoid spills (#2219)

With the usual register allocation (producer 40, consumer 232) compiling Gemm
with tile shape 256 x 208 (cooperative) or 128 x 208 (pingpong) show lots of
register spilling (e.g. ~3000 bytes spill). For this case we can change
the register allocation to producer 24, consumer 240, which avoids spills.
This commit is contained in:
Tri Dao
2025-04-21 00:02:30 -04:00
committed by GitHub
parent bb4dd682dd
commit ade6376fa0
2 changed files with 12 additions and 4 deletions

View File

@ -128,8 +128,12 @@ public:
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
static constexpr int RegsPerThread =
size<0>(TileShape{}) * size<1>(TileShape{}) / NumMMAThreads *
sizeof(ElementAccumulator) / sizeof(uint32_t);
static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208;
static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24;
static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;

View File

@ -138,8 +138,12 @@ public:
static_assert(MaxThreadsPerBlock == 384, "Pingpong kernel must have 384 threads in total.");
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
static constexpr int RegsPerThread =
size<0>(TileShape{}) * size<1>(TileShape{}) / NumMMAThreads *
sizeof(ElementAccumulator) / sizeof(uint32_t);
static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208;
static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24;
static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;