Feature/add bottom causal mask (#2480)

* Rebase to latest

* update

* upd

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Update fmha_fusion.hpp

* Update fmha_fusion.hpp

Fixed flipped logic for IsQBegin

* Update fmha_fusion.hpp

* Avoid use of booleans

The current expression is confusing

* fmt

* Update fmha_fusion.hpp

Reproduce the error (and verify the fix) with:
./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend

* add test, format

---------

Co-authored-by: Richard Cai <ricai@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
This commit is contained in:
Aya Z. Ibrahim
2025-09-18 14:11:23 -07:00
committed by Haicheng Wu
parent 177a82e251
commit c609b86db2
2 changed files with 7 additions and 5 deletions

View File

@ -39,7 +39,8 @@ set_property(
)
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend)
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
@ -119,7 +120,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_CAUSAL
TEST_CAUSAL_00
TEST_CAUSAL_01
TEST_VARLEN
TEST_HDIM64
TEST_GQA
@ -222,7 +224,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla_fwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_CAUSAL
TEST_CAUSAL_00
TEST_VARLEN
TEST_HDIM64
TEST_GQA

View File

@ -225,8 +225,8 @@ struct CausalMask : NoMask {
if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
}
}