Set EpiTile correctly when TileN is not divisible by 32 (#2220)

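If TileN is not divisible by 32 (e.g., 208), EpiTile would by default be set
to 128 x 32, which does not compile because EpiTileN is required to divide TileN.
Taking the gcd of the default EpiTileN with TileN (e.g., gcd(32, 208) = 16) yields a valid EpiTileN.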
This commit is contained in:
Tri Dao
2025-04-21 00:02:51 -04:00
committed by GitHub
parent ade6376fa0
commit 81a43e6d92

View File

@@ -116,13 +116,13 @@ sm90_compute_tile_shape_or_override() {
auto epi_tile = [&] () {
if constexpr (detail::sm90_is_cooperative_v<Schedule>) {
auto tile_m = cute::min(_128{}, size<0>(TileShape_MNK{}));
auto tile_n = cute::min(_32{}, size<1>(TileShape_MNK{}));
auto tile_n = cute::gcd(cute::min(_32{}, size<1>(TileShape_MNK{})), size<1>(TileShape_MNK{}));
return make_shape(tile_m, tile_n);
}
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
constexpr int N_perf = sizeof_bits_v<ElementD> == 8 ? 64 : 32;
auto tile_m = cute::min(_64{}, size<0>(TileShape_MNK{}));
auto tile_n = cute::min(Int<N_perf>{}, size<1>(TileShape_MNK{}));
auto tile_n = cute::gcd(cute::min(Int<N_perf>{}, size<1>(TileShape_MNK{})), size<1>(TileShape_MNK{}));
return make_shape(tile_m, tile_n);
}
else {