Set EpiTile correctly when TileN is not divisible by 32 (#2220)
If TileN is not divisible by 32 (e.g., 208), EpiTile would by default be set to 128 x 32, which does not compile because EpiTileN is required to evenly divide TileN.
This commit is contained in:
@ -116,13 +116,13 @@ sm90_compute_tile_shape_or_override() {
|
||||
auto epi_tile = [&] () {
|
||||
if constexpr (detail::sm90_is_cooperative_v<Schedule>) {
|
||||
auto tile_m = cute::min(_128{}, size<0>(TileShape_MNK{}));
|
||||
auto tile_n = cute::min(_32{}, size<1>(TileShape_MNK{}));
|
||||
auto tile_n = cute::gcd(cute::min(_32{}, size<1>(TileShape_MNK{})), size<1>(TileShape_MNK{}));
|
||||
return make_shape(tile_m, tile_n);
|
||||
}
|
||||
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
|
||||
constexpr int N_perf = sizeof_bits_v<ElementD> == 8 ? 64 : 32;
|
||||
auto tile_m = cute::min(_64{}, size<0>(TileShape_MNK{}));
|
||||
auto tile_n = cute::min(Int<N_perf>{}, size<1>(TileShape_MNK{}));
|
||||
auto tile_n = cute::gcd(cute::min(Int<N_perf>{}, size<1>(TileShape_MNK{})), size<1>(TileShape_MNK{}));
|
||||
return make_shape(tile_m, tile_n);
|
||||
}
|
||||
else {
|
||||
|
||||
Reference in New Issue
Block a user