Feature/add bottom causal mask (#2480)
* Rebase to latest * update * upd Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Update fmha_fusion.hpp * Update fmha_fusion.hpp fixed flipped logic for isQBegin * Update fmha_fusion.hpp * Avoid use of booleans The current expression is confusing * fmt * Update fmha_fusion.hpp Reproduce error/fix with: ./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend * add test, format --------- Co-authored-by: Richard Cai <ricai@nvidia.com> Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
This commit is contained in:
@ -39,7 +39,8 @@ set_property(
|
||||
)
|
||||
|
||||
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
|
||||
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend)
|
||||
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
|
||||
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
|
||||
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
|
||||
@ -119,7 +120,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
TEST_CAUSAL
|
||||
TEST_CAUSAL_00
|
||||
TEST_CAUSAL_01
|
||||
TEST_VARLEN
|
||||
TEST_HDIM64
|
||||
TEST_GQA
|
||||
@ -222,7 +224,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_mla_fwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
TEST_CAUSAL
|
||||
TEST_CAUSAL_00
|
||||
TEST_VARLEN
|
||||
TEST_HDIM64
|
||||
TEST_GQA
|
||||
|
||||
@ -225,8 +225,8 @@ struct CausalMask : NoMask {
|
||||
if constexpr (IsQBegin) {
|
||||
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
|
||||
} else {
|
||||
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
|
||||
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
|
||||
const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape);
|
||||
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user