Example 77: add Blackwell FMHA bwd for MLA shape (#2466)

* Update examples/77_blackwell_fmha/device/fmha_device_bwd.hpp

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>

* Bug fix & use existing value rather than passing one more argument to support different dims in bwd_convert

* Fix causal mask count when IsQBegin==false

* Bug fix in causal mask backward

* Code sync

---------

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
Zeyu WANG
2025-07-25 06:41:11 +08:00
committed by GitHub
parent 9a9a579714
commit 0e026982ce
13 changed files with 428 additions and 77 deletions

View File

@ -126,6 +126,7 @@ struct Options {
bool verbose = false;
bool causal = false;
bool causal_q_begin = true;
bool residual = false;
bool varlen = false;
bool persistent = false;
@ -266,6 +267,8 @@ struct Options {
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
std::string causal_type;
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
if (mask == "no" || mask == "") {
causal = residual = false;
if (varlen) {
@ -275,6 +278,11 @@ struct Options {
else if (mask == "causal") {
residual = false;
causal = true;
if (causal_type == "qend") {
causal_q_begin = false;
} else {
causal_q_begin = true;
}
}
else if (mask == "residual") {
residual = true;
@ -313,6 +321,7 @@ struct Options {
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --causal-type=<qbegin|qend> Causal mask type\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
@ -1078,7 +1087,11 @@ int main_single(int argc, char const **args) {
auto with_mask = [&](auto fn) {
if (options.causal) {
if (options.causal_q_begin) {
fn(CausalMask{});
} else {
fn(CausalMask<false>{});
}
}
else if (options.residual) {
fn(ResidualMask{});
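(Note: with the dispatch above, --mask=causal --causal-type=qend selects CausalMask<false>, while the default qbegin keeps the existing CausalMask path; e.g. 77_blackwell_fmha_fp16 --verify --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4, matching the new CTest entries later in this commit.)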

View File

@ -816,7 +816,7 @@ struct BwdRunner {
runtime_ms /= static_cast<float>(options.iterations);
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask<false>> || std::is_same_v<ActiveMask, CausalForBackwardMask<true>> ? 0.5 : 1.0);
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
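(For reference, the 3*D + 2*D_VO factor counts the five backward GEMMs at 2*Q*K FLOPs per feature each: the S = Q*K^T recompute, dQ, and dK use head dim D, while dP and dV use head dim D_VO; the 0.5 factor halves the count under a causal mask.)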

View File

@ -80,6 +80,7 @@ struct Options {
int iterations = 3;
bool verify = false;
bool verbose = false;
bool is_fused_reduction = false;
int sm_count = 0;
@ -139,9 +140,12 @@ struct Options {
if (b == 0) b = 1;
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
if (split_kv == 0) {
split_kv = 1;
}
cmd.get_cmd_line_argument("page", page, defaults.page);
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
is_var_split_kv = cmd.check_cmd_line_flag("var_split_kv");
if (page == -1) {
is_var_split_kv = false;
}
@ -149,6 +153,10 @@ struct Options {
if (is_var_split_kv == true) {
split_kv = max_split_kv;
}
is_fused_reduction = cmd.check_cmd_line_flag("fuse_reduction");
if (split_kv == 1) {
is_fused_reduction = false;
}
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
@ -176,6 +184,8 @@ struct Options {
<< " --iterations=<int> Benchmarking iterations\n"
<< " --spread=<float> Relative spread away from K for paging\n"
<< " --split_kv=<int> Split KV factor\n"
<< " --fused_reduction Fuse the reduction operation\n"
<< " --var_split_kv Use varying split KV factor\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --sm-count Sets SM count rather than querying it\n"
@ -514,7 +524,8 @@ struct Runner {
stride_LSE},
hw_info,
options.split_kv,
options.is_var_split_kv ? block_split_kv.get() : nullptr,
options.is_fused_reduction
};
if (options.split_kv < 0 && !options.is_var_split_kv) {
Operation::set_split_kv(arguments);
@ -724,13 +735,17 @@ void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
if (!options.is_fused_reduction || options.split_kv == 1) {
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
}
#elif FP16
name += " fp16";
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
if (!options.is_fused_reduction || options.split_kv == 1) {
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
}
#endif
}
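(The fused path requires the persistent scheduler: can_implement in the kernel, shown later in this commit, rejects the individual tile scheduler with fused reduction, which is why the individual run is skipped above. A fused run mirrors the new TEST_MLA_FUSE_REDUCTION entry: 77_blackwell_mla_2sm_fp16 --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify.)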

View File

@ -90,6 +90,7 @@ struct Options {
bool verbose = false;
bool causal = false;
bool causal_q_begin = true;
bool residual = false;
bool varlen = false;
bool persistent = false;
@ -231,6 +232,8 @@ struct Options {
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
std::string causal_type;
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
if (mask == "no" || mask == "") {
causal = residual = false;
if (varlen) {
@ -240,6 +243,11 @@ struct Options {
else if (mask == "causal") {
residual = false;
causal = true;
if (causal_type == "qend") {
causal_q_begin = false;
} else {
causal_q_begin = true;
}
}
else if (mask == "residual") {
residual = true;
@ -279,6 +287,7 @@ struct Options {
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --causal-type=<qbegin|qend> Causal mask type\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
@ -1013,7 +1022,11 @@ int main_single(int argc, char const **args) {
auto with_mask = [&](auto fn) {
if (options.causal) {
if (options.causal_q_begin) {
fn(CausalMask{});
} else {
fn(CausalMask<false>{});
}
}
else if (options.residual) {
fn(ResidualMask{});

View File

@ -59,6 +59,14 @@ set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
@ -75,6 +83,13 @@ set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --d
set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_MLA_FWD_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_MLA_FWD_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_MLA_FWD_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_MLA_FWD_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
@ -87,6 +102,9 @@ set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
set(TEST_MLA_SEP_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --verify)
set(TEST_MLA_FUSE_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
foreach(PREC fp8 fp16)
@ -116,6 +134,12 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_VARLEN_12
TEST_VARLEN_13
TEST_VARLEN_14
TEST_VARLEN_15
TEST_VARLEN_16
TEST_VARLEN_17
TEST_VARLEN_18
TEST_VARLEN_19
TEST_VARLEN_20
)
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
@ -139,6 +163,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
TEST_MLA_SEP_REDUCTION
TEST_MLA_FUSE_REDUCTION
)
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
@ -149,6 +175,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
TEST_MLA_SEP_REDUCTION
TEST_MLA_FUSE_REDUCTION
)
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
@ -207,6 +235,12 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_MLA_FWD_VARLEN_12
TEST_MLA_FWD_VARLEN_13
TEST_MLA_FWD_VARLEN_14
TEST_MLA_FWD_VARLEN_15
TEST_MLA_FWD_VARLEN_16
TEST_MLA_FWD_VARLEN_17
TEST_MLA_FWD_VARLEN_18
TEST_MLA_FWD_VARLEN_19
TEST_MLA_FWD_VARLEN_20
)
target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})

View File

@ -203,13 +203,12 @@ struct CausalMask : NoMask {
// See note below on different ways to think about causal attention
// Again, we'd add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
if constexpr (IsQBegin) {
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
} else {
const int offset_q = get<1>(problem_size) - get<0>(problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}
@ -222,12 +221,12 @@ struct CausalMask : NoMask {
TileShape const& tile_shape,
ProblemSize const& problem_size) {
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
}
}
@ -277,9 +276,10 @@ struct CausalMask : NoMask {
}
};
template<bool kIsQBegin = true>
struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
using Base = CausalMask<kIsQBegin>;
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
@ -296,10 +296,15 @@ struct CausalForBackwardMask : CausalMask<true>, ResidualMaskForBackward {
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
int offset_q = 0;
if constexpr (!kIsQBegin) {
offset_q = get<1>(problem_size) - get<0>(problem_size);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);
if (masked) {
acc_qk(i) = -INFINITY;
}
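To make the qbegin/qend distinction concrete, here is a minimal host-side sketch (hypothetical standalone code, not part of this commit) of the predicate used above: qbegin anchors the causal diagonal at the first query row, while qend shifts every row by offset_q = K - Q so the diagonal ends at the bottom-right corner.

#include <cstdio>

// Hypothetical standalone illustration of the mask predicate
// bool masked = (get<0>(pos) + offset_q < get<1>(pos)), with
// offset_q = 0 for qbegin and offset_q = K - Q for qend.
bool is_masked(int q, int k, int Q, int K, bool q_begin) {
  int offset_q = q_begin ? 0 : K - Q;
  return q + offset_q < k;
}

int main() {
  int Q = 2, K = 4;
  for (int q = 0; q < Q; q++) {
    for (int k = 0; k < K; k++)
      printf("%c", is_masked(q, k, Q, K, /*q_begin=*/false) ? '.' : 'x');
    printf("\n");
  }
  // qend with Q=2, K=4 prints "xxx." then "xxxx": the last query row
  // attends to all K keys, as in decode-style attention.
  return 0;
}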

View File

@ -100,7 +100,7 @@ public:
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
>;
using OperationMha = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
@ -112,7 +112,7 @@ public:
>
>;
using Operation = std::conditional_t<IsMla, OperationMla, OperationMha>;
using Kernel = typename Operation::Kernel;

View File

@ -127,7 +127,11 @@ public:
int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves);
if (args.is_fused_reduction && split_wave_aware > 1) {
args.split_kv = std::min(split_wave_aware, static_cast<int>(sm_count/2));
} else {
args.split_kv = split_wave_aware;
}
}
/// Determines whether the GEMM can execute the given problem.
@ -273,11 +277,33 @@ public:
CUTLASS_TRACE_HOST("MLA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
auto [H, K, D, B] = params.fmha_params.problem_shape;
auto [D_latent, D_rope] = D;
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
if (params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
auto result = cudaMemsetAsync(params.fmha_params.epilogue.ptr_o, 0, sizeof(typename Kernel::ElementOut) * H * D_latent * B, stream);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
auto total_bytes = H * B * (sizeof(int) + sizeof(typename Kernel::ElementLSE)) + 2 * B * sizeof(int);
uint8_t* ws = reinterpret_cast<uint8_t*>(params.fmha_params.epilogue.ptr_lse_exchange_buff);
result = cudaMemsetAsync(ws, 0, total_bytes, stream);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
@ -298,7 +324,7 @@ public:
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
if (!params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
// launch reduction kernel
dim3 const block = ReductionKernel::get_block_shape();
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);

View File

@ -1245,9 +1245,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
int kv_right = kv_left + TileShapeK{} - 1;
int q_left = iter_index * TileShapeQ{} + offset;
int q_right = q_left + TileShapeQ{} - 1;
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
@ -1683,9 +1690,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
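The qend variant changes both where a KV tile intersects the causal boundary (leading_causal_masking) and the first Q tile that contributes to a KV block (iter_start). A host-side sketch with plain ints (assumed tile sizes; the MLA backward kernel below repeats the same logic):

#include <algorithm>

// Does KV tile [kv_left, kv_right] straddle the shifted diagonal of
// Q tile iter_index? Mirrors the leading_causal_masking test above.
bool needs_causal_mask(int iter_index, int kv_block,
                       int Q, int K, int TileQ, int TileK) {
  int offset   = K - Q;                 // qend shift
  int kv_left  = kv_block * TileK;
  int kv_right = kv_left + TileK - 1;
  int q_left   = iter_index * TileQ + offset;
  int q_right  = q_left + TileQ - 1;
  return !((q_left > kv_right) || (q_right < kv_left));
}

// First Q tile whose rows can be unmasked for a given KV block;
// mirrors the iter_start computation above.
int first_iter(int kv_block, int Q, int K, int TileQ, int TileK) {
  int offset = K - Q;
  return std::max(0, (kv_block * TileK - offset) / TileQ);
}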

View File

@ -1230,9 +1230,16 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
int kv_right = kv_left + TileShapeK{} - 1;
int q_left = iter_index * TileShapeQ{} + offset;
int q_right = q_left + TileShapeQ{} - 1;
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
@ -1677,9 +1684,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;

View File

@ -101,8 +101,17 @@ struct Sm100FmhaMlaReductionKernel {
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
if (params.split_kv <= 1) return;
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
if (local_split_kv == 1) return;
__shared__ ElementAcc sLseScale[kMaxSplits];
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
@ -113,12 +122,6 @@ struct Sm100FmhaMlaReductionKernel {
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
Shape<_1>{}, Stride<_1>{});
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
@ -130,17 +133,18 @@ struct Sm100FmhaMlaReductionKernel {
const int split = i * 32 + threadIdx.x;
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
}
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
lse_max = fmax(local_lse[i], lse_max);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
lse_max = fmax(__shfl_xor_sync(0xffffffff, lse_max, offset), lse_max);
}
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
ElementAcc sum_lse = 0;
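The max is found with a standard warp-wide butterfly reduction; a minimal CUDA sketch of the pattern (assuming a full 32-lane warp):

__device__ float warp_max(float v) {
  // XOR-butterfly: after log2(32) = 5 rounds every lane holds the max.
  #pragma unroll
  for (int offset = 16; offset >= 1; offset /= 2) {
    v = fmaxf(__shfl_xor_sync(0xffffffff, v, offset), v);
  }
  return v;
}

The subsequent clamp of a result that is still -inf to 0.0f keeps the later exp(lse - lse_max) well defined (it avoids -inf minus -inf, which is NaN).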

View File

@ -36,6 +36,7 @@
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/barrier.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
@ -44,6 +45,7 @@
#include "gather_tensor.hpp" // from examples/common
#include "common/pow_2.hpp"
#include "sm100_mla_tile_scheduler.hpp"
namespace cutlass::fmha::kernel {
@ -87,8 +89,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using TileShapeR = tuple_element_t<1, TileShapeD>;
static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim");
using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;
using TensorStride = Stride<int64_t, _1, int64_t>;
using TmemAllocator = cute::conditional_t<kIs2Sm, cute::TMEM::Allocator2Sm, cute::TMEM::Allocator1Sm>;
static_assert(TileShapeH{} == 128);
@ -181,10 +183,13 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB;
using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int<IterationsPV_K>{}, _2{})));
using SmemLayoutOut = decltype(take<0,2>(typename CollectiveMmaQK::CtaShape_MNK{}));
using TileShapeAcc = decltype(take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}));
static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v<Element>);
static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v<Element>);
// pre-condition for overlapped smem staging
static_assert(kBytesLoadKC == kBytesLoadVC);
static_assert(StagesQK == StagesPV);
@ -226,7 +231,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKC>> smem_kc;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutVC>> smem_vc;
};
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
alignas(2048) cute::array<ElementOut, size(TileShapeAcc{})> smem_acc;
};
};
struct SharedStorage {
@ -280,6 +288,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
KernelHardwareInfo hw_info;
int split_kv = -1;
int* ptr_split_kv = nullptr;
bool is_fused_reduction = false;
};
using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A;
@ -288,6 +297,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B;
using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B;
using GmemLayout = decltype(make_layout(Shape<int,int,int>{}, Stride<int64_t, _1, int64_t>{}));
using SmemLayout = decltype(make_layout(TileShapeAcc{}, LayoutRight{}));
using TmaReduceSum = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor(recast_ptr<ElementOut>(nullptr), GmemLayout{}), SmemLayout{}));
struct MainloopParams {
TmaLoadQLatent tma_load_q_latent;
TmaLoadQRope tma_load_q_rope;
@ -306,6 +321,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
Stride<_1, int> stride_lse;
Stride<_1, int> stride_lse_acc;
ElementAcc output_scale = 1.0f;
ElementLSE* ptr_lse_exchange_buff = nullptr;
int* ptr_lse_max_exchange_buff = nullptr;
int* ptr_lock = nullptr; // semaphore
TmaReduceSum tma_reduce_sum;
};
struct Params {
@ -316,6 +335,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
typename TileScheduler::Params tile_scheduler;
int split_kv = -1;
int* ptr_split_kv = nullptr;
bool is_fused_reduction = false;
};
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
@ -380,11 +400,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
epilogue_params.ptr_o = args.epilogue.ptr_o;
epilogue_params.stride_o = args.epilogue.stride_o;
epilogue_params.ptr_lse = args.epilogue.ptr_lse;
epilogue_params.stride_lse = args.epilogue.stride_lse;
epilogue_params.output_scale = args.epilogue.output_scale;
epilogue_params.tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(recast_ptr<ElementOut>(args.epilogue.ptr_o), make_layout(make_shape(H, L, B), args.epilogue.stride_o)), SmemLayout{});
if (!args.is_fused_reduction && args.split_kv > 1) {
ElementAcc* ptr_o_acc = reinterpret_cast<ElementAcc*>(workspace);
ElementLSE* ptr_lse_acc = reinterpret_cast<ElementLSE*>(ptr_o_acc + H * L * args.split_kv * B);
epilogue_params.ptr_o_acc = ptr_o_acc;
@ -392,10 +413,18 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
epilogue_params.stride_o_acc = make_tuple(static_cast<int64_t>(0 + L) * args.split_kv, _1{}, static_cast<int64_t>(0 + H * L) * args.split_kv);
epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv);
} else if (args.is_fused_reduction && args.split_kv > 1) {
ElementLSE* ptr_lse_exchange_buff = reinterpret_cast<ElementLSE*>(workspace);
epilogue_params.ptr_lse_exchange_buff = ptr_lse_exchange_buff;
int* ptr_lse_max_exchange_buff = reinterpret_cast<int*>(ptr_lse_exchange_buff + H * B);
epilogue_params.ptr_lse_max_exchange_buff = ptr_lse_max_exchange_buff;
int* ptr_lock = ptr_lse_max_exchange_buff + H * B;
epilogue_params.ptr_lock = ptr_lock;
}
return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params,
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv),
args.split_kv, args.ptr_split_kv, args.is_fused_reduction};
}
static size_t get_workspace_size(Arguments const& args) {
@ -403,10 +432,29 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
auto split_kv = args.split_kv;
size_t workspace_size {0};
if (args.is_fused_reduction && args.split_kv > 1) {
// one exchange buffer for LSE max and another buffer for total LSE
// two locks per batch, first lock is for CTA0 / H=0..63 and the second is for CTA1 / H=64..127
workspace_size = H * B * (sizeof(int) + sizeof(ElementLSE)) + 2 * B * sizeof(int);
} else if (!args.is_fused_reduction && args.split_kv > 1) {
workspace_size = (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
}
return workspace_size;
}
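A quick host-side check of the two formulas, under assumed types (ElementAcc = ElementLSE = float) and a hypothetical decode shape:

#include <cstdio>

int main() {
  size_t H = 128, B = 1, D_latent = 512, split_kv = 8;
  // fused: LSE-max buffer (int) + LSE-sum buffer (float) per head/batch,
  // plus two semaphore locks per batch
  size_t fused = H * B * (sizeof(int) + sizeof(float)) + 2 * B * sizeof(int);
  // non-fused: per-split O accumulator + per-split LSE accumulator
  size_t split = (sizeof(float) * D_latent + sizeof(float)) * H * split_kv * B;
  printf("fused: %zu bytes, non-fused: %zu bytes\n", fused, split); // 1032 vs 2101248
  return 0;
}

The fused path trades the O(split_kv) accumulator storage for a constant-size exchange buffer, at the cost of the inter-CTA synchronization shown below.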
static Status initialize_workspace(
Arguments const& args, void* ws, cudaStream_t stream) {
auto workspace_size = get_workspace_size(args);
if (args.is_fused_reduction && args.split_kv > 1) {
auto result = cudaMemsetAsync(ws, 0, workspace_size, stream);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
@ -448,6 +496,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
std::cerr << __FILE__ << "(" << __LINE__ << "): split-k off\n";
return false;
}
if (args.is_fused_reduction && args.split_kv > 1) {
if (2 * args.split_kv > args.hw_info.sm_count ||
std::is_same_v<TileScheduler, Sm100MlaIndividualTileScheduler>) {
return false;
}
}
return true;
}
@ -746,7 +800,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_consumer_state,
pipeline_p_mma, pipeline_p_mma_producer_state,
pipeline_mma_o, pipeline_mma_o_consumer_state,
local_split_kv
local_split_kv,
params.is_fused_reduction
);
}
@ -1777,7 +1832,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
if (split_kv > 1) {
using ElementOutAcc = ElementAcc;
constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v<ElementOutAcc>;
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc);
@ -1806,16 +1862,20 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
copy(tTR_rO_src, tR2G_rO_dst);
if (get<1>(cta_coord) == 0) {
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
}
}
}
}
else {
@ -1845,24 +1905,165 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
copy(tTR_rO_src, tR2G_rO_dst);
if (get<1>(cta_coord) == 0) {
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
}
}
}
}
}
template<class BlkCoord>
CUTLASS_DEVICE ElementLSE epilogue_lse_reduction(
ElementAcc& row_max,
ElementAcc& row_sum,
BlkCoord const& cta_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueParams const& epilogue_args,
int const& local_split_kv) {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{});
constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp;
using Sync = cutlass::detail::NamedBarrierSync<kNumThreads, kNamedBarrierExchange>;
auto wait = [](int* lock, int count) {
__threadfence();
if (threadIdx.x == 0) {
atomicAdd(lock, 1);
while (atomicCAS(lock, count, count) != count) {};
}
__threadfence();
Sync::sync();
};
const ElementLSE lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
Tensor mLSE_max_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_max_exchange_buff), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE_max_buff = local_tile(mLSE_max_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
int* local_lock = epilogue_args.ptr_lock + get<0>(cta_coord) + 2 * get<2>(cta_coord);
if (! kIs2Sm || (threadIdx.x < 64)) {
atomicMax(&(gLSE_max_buff(threadIdx.x)), __float2int_rn(lse));
}
wait(local_lock, local_split_kv);
auto global_lse_max = static_cast<ElementLSE>(gLSE_max_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x));
Tensor mLSE_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_exchange_buff), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE_buff = local_tile(mLSE_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
if (! kIs2Sm || (threadIdx.x < 64)) {
atomicAdd(&(gLSE_buff(threadIdx.x)), expf(lse - global_lse_max));
}
wait(local_lock, 2*local_split_kv);
const auto sum_lse = gLSE_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x);
const auto global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementLSE>::infinity() :
cutlass::fast_log(sum_lse) + global_lse_max;
const auto lse_scale = expf(lse - global_lse);
if (epilogue_args.ptr_lse != nullptr) {
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// write out the global LSE
if (! kIs2Sm || (threadIdx.x < 64)) {
gLSE(threadIdx.x) = global_lse;
}
}
return lse_scale;
}
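Mathematically, this is the usual numerically stable log-sum-exp merge; a host-side sketch with floats (hypothetical, for illustration). Note the shift m only needs to be finite and close to the true max, which is why exchanging a rounded integer max via atomicMax on __float2int_rn(lse) above still gives a correct result up to float rounding:

#include <cmath>
#include <vector>

// global_lse = log(sum_i exp(lse_i - m)) + m holds for any finite m;
// each split then scales its partial output by exp(lse_i - global_lse).
float merge_lse(std::vector<float> const& lse, std::vector<float>& scale) {
  float m = -INFINITY;
  for (float v : lse) m = std::fmax(m, v);
  float s = 0.f;
  for (float v : lse) s += std::exp(v - m);
  float global_lse = std::log(s) + m;
  scale.clear();
  for (float v : lse) scale.push_back(std::exp(v - global_lse)); // sums to 1
  return global_lse;
}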
template<class BlkCoord>
CUTLASS_DEVICE void epilogue_reduction(
ElementAcc& row_max,
ElementAcc& row_sum,
BlkCoord const& blk_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueParams const& epilogue_args,
TensorStorage& shared_tensors,
int const& local_split_kv,
ElementLSE const& lse_scale) {
constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp;
using Sync = cutlass::detail::NamedBarrierSync<kNumThreads, kNamedBarrierExchange>;
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
TiledMmaPV tiled_mma_pv;
Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{})));
CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{});
CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{});
using EpilogueLinearCombination = cutlass::epilogue::thread::LinearCombination<ElementOut, 1, ElementAcc, ElementAcc, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
EpilogueLinearCombination epilogue_op({epilogue_args.output_scale / row_sum * lse_scale});
CUTLASS_PRAGMA_UNROLL
for(int k = 0; k < IterationsPV_N; ++k) {
auto cta_coord = replace<1>(blk_coord, k);
uint32_t tmem_o = uint32_t(TmemAllocation::kO0) + k * uint32_t(TmemAllocation::kSizeAccO);
tOtO.data() = tmem_o;
Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{});
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o);
Tensor gO = local_tile(mO, TileShapeAcc{}, take<0,3>(cta_coord));
auto tiled_t2r = make_tmem_copy(load_op, tAcc);
auto thread_idx = threadIdx.x % size(tiled_t2r);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_gO = thread_t2r.partition_D(gO);
Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);
copy(tiled_t2r, tTR_tAcc, tTR_rAcc);
Tensor sO = make_tensor(make_smem_ptr(reinterpret_cast<ElementOut*>(shared_tensors.smem_acc.begin())), SmemLayout{});
Tensor tTR_sO = thread_t2r.partition_D(sO);
Sync::sync();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i++) {
tTR_sO(i) = epilogue_op(tTR_rAcc(i));
}
tma_store_fence();
Sync::sync();
auto tma_reduce_sum_per_cta = epilogue_args.tma_reduce_sum.get_slice(_0{});
auto gmem_tensor_coord = epilogue_args.tma_reduce_sum.get_tma_tensor(shape(mO));
auto gmem_tensor_coord_per_cta = local_tile(gmem_tensor_coord, TileShapeAcc{}, take<0,3>(cta_coord));
if (threadIdx.x % kNumThreads == 0) {
copy(epilogue_args.tma_reduce_sum,
tma_reduce_sum_per_cta.partition_S(sO),
tma_reduce_sum_per_cta.partition_D(gmem_tensor_coord_per_cta));
tma_store_arrive();
}
tma_store_wait<0>();
}
}
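(This path depends on the device-side zero-fill of ptr_o and of the exchange buffers issued before launch in the device layer above: each split converts its TMEM accumulator to ElementOut in smem, scaled by output_scale / row_sum * lse_scale, and accumulates it into O with SM90_TMA_REDUCE_ADD, so once all splits have arrived O holds the softmax-weighted sum without a separate reduction kernel.)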
template<class CtaCoord>
CUTLASS_DEVICE void compute(
@ -1877,7 +2078,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
typename PipelineP::PipelineState& pipeline_p_mma_producer_state,
PipelineO& pipeline_mma_o,
typename PipelineO::PipelineState& pipeline_mma_o_consumer_state,
int const& split_kv) {
int const& split_kv,
bool const& is_fused_reduction) {
auto [H, K, D, B] = problem_shape;
@ -1987,17 +2189,38 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();
const int actual_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
if (!is_fused_reduction || actual_split_kv == 1) {
// epilogue
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IterationsPV_N; j++) {
epilogue(
row_max, row_sum,
replace<1>(cta_coord, j), problem_shape,
mainloop_args, epilogue_args, shared_tensors,
uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO),
actual_split_kv
);
}
} else {
const ElementLSE lse_scale =
epilogue_lse_reduction(
row_max, row_sum,
cta_coord,
problem_shape,
mainloop_args, epilogue_args,
actual_split_kv
);
epilogue_reduction(row_max, row_sum,
cta_coord,
problem_shape,
mainloop_args, epilogue_args,
shared_tensors,
actual_split_kv,
lse_scale
);
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
++pipeline_mma_o_consumer_state;

View File

@ -142,8 +142,8 @@ struct Sm100MlaPersistentTileScheduler {
int block_decode = block_idx;
int m_block, bidb, n_split_kv;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
return make_coord(m_block, _0{}, bidb, n_split_kv);
}
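The swap makes split_kv vary faster than batch when decoding the flattened block index, so all split_kv pieces of a given batch are issued before moving to the next batch. A plain-int sketch of the new decode order (the divmod_* members in the kernel are fast-division helpers):

// block_idx = m_block + num_m_blocks * (split + num_splits * b)
void decode(int block_idx, int num_m_blocks, int num_splits,
            int& m_block, int& split, int& b) {
  m_block = block_idx % num_m_blocks;  block_idx /= num_m_blocks;
  split   = block_idx % num_splits;    block_idx /= num_splits;
  b       = block_idx;                 // batch varies slowest
}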