Handle get_masked_trip_count for small length in fmha example (#2292)

* handle get_masked_trip_count for small length

* Update examples/77_blackwell_fmha/collective/fmha_fusion.hpp

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>

* Update examples/77_blackwell_fmha/collective/fmha_fusion.hpp

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>

---------

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
This commit is contained in:
Taebum Kim
2025-05-31 11:51:18 +09:00
committed by GitHub
parent b9b110a9ea
commit 9d165a3b8e

View File

@@ -157,7 +157,8 @@ struct CausalMask : NoMask {
TileShape const& tile_shape,
ProblemSize const& problem_size) {
-return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
+int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
+return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
}
template<class BlkCoord, class TileShape, class ProblemSize>