Handle get_masked_trip_count for small length in fmha example (#2292)
* handle get_masked_trip_count for small length
* Update examples/77_blackwell_fmha/collective/fmha_fusion.hpp
  Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
* Update examples/77_blackwell_fmha/collective/fmha_fusion.hpp
  Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
---------
Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
This commit is contained in:
@@ -157,7 +157,8 @@ struct CausalMask : NoMask {
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
|
||||
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
|
||||
Reference in New Issue
Block a user