Example 77 add blackwell fmha bwd for MLA shape (#2466)
* Update examples/77_blackwell_fmha/device/fmha_device_bwd.hpp

  Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>

* bug fix & use existing value rather than pass one more argument to support different dim in bwd_convert

* Fix causal mask cnt when IsQBegin==false

* bug fix in causal mask backward

* code sync

---------

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
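For orientation (an illustrative note, not part of the commit diff; values taken from TEST_BWD_MLA_BASIC further down):

// MLA backward shape sketch: Q/K head dim d = 192 (128 latent + 64 rope),
// V/dO head dim d_vo = 128. Backward runs three GEMM contractions over d
// (S = Q K^T, dQ = dS K, dK = dS^T Q) and two over d_vo (dV = P^T dO,
// dP = dO V^T), which is where the factor (3 * d + 2 * d_vo) in BwdRunner's
// FLOP estimate below comes from.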
@@ -126,6 +126,7 @@ struct Options {
   bool verbose = false;
 
   bool causal = false;
+  bool causal_q_begin = true;
   bool residual = false;
   bool varlen = false;
   bool persistent = false;
@@ -266,6 +267,8 @@ struct Options {
 
     std::string mask;
     cmd.get_cmd_line_argument<std::string>("mask", mask, "");
+    std::string causal_type;
+    cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
     if (mask == "no" || mask == "") {
       causal = residual = false;
       if (varlen) {
@@ -275,6 +278,11 @@ struct Options {
     else if (mask == "causal") {
       residual = false;
       causal = true;
+      if (causal_type == "qend") {
+        causal_q_begin = false;
+      } else {
+        causal_q_begin = true;
+      }
     }
     else if (mask == "residual") {
       residual = true;
@@ -313,6 +321,7 @@ struct Options {
       << "  --verify                     Verify results\n"
       << "  --verbose                    Print smem and execution time per kernel\n"
       << "  --mask=<no|residual|causal>  Enables masking\n"
+      << "  --causal-type=<qbegin|qend>  Causal mask type\n"
       << "  --persistent                 Enables persistent scheduler\n"
       << "  --varlen                     Enables variable sequence length\n"
       << "                               B*Q and B*K become the total sequence length\n"
@@ -1078,7 +1087,11 @@ int main_single(int argc, char const **args) {
 
   auto with_mask = [&](auto fn) {
     if (options.causal) {
-      fn(CausalMask{});
+      if (options.causal_q_begin) {
+        fn(CausalMask{});
+      } else {
+        fn(CausalMask<false>{});
+      }
     }
     else if (options.residual) {
       fn(ResidualMask{});
 
@@ -816,7 +816,7 @@ struct BwdRunner {
 
     runtime_ms /= static_cast<float>(options.iterations);
 
-    double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
+    double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask<false>> || std::is_same_v<ActiveMask, CausalForBackwardMask<true>> ? 0.5 : 1.0);
     flops *= static_cast<double>(get<0>(problem_shape));
     flops *= static_cast<double>(get<1>(problem_shape));
     flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
 
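A worked instance of this FLOP count (illustrative numbers, matching the TEST_BWD_MLA_BASIC configuration below):

// Per batch x head, with Q = K = 512, d = 192, d_vo = 128:
//   flops = 2 * 512 * 512 * (3 * 192 + 2 * 128)
//         = 2 * 262144 * 832 ≈ 4.36e8 FLOPs.
// The leading factor becomes 2 * 0.5 when either causal alignment is active,
// since a causal mask touches only about half of the Q x K tiles.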
@@ -80,6 +80,7 @@ struct Options {
   int iterations = 3;
   bool verify = false;
   bool verbose = false;
+  bool is_fused_reduction = false;
 
   int sm_count = 0;
 
@@ -139,9 +140,12 @@ struct Options {
     if (b == 0) b = 1;
 
     cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
     if (split_kv == 0) {
       split_kv = 1;
     }
     cmd.get_cmd_line_argument("page", page, defaults.page);
     cmd.get_cmd_line_argument("spread", spread, defaults.spread);
-    cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
+    is_var_split_kv = cmd.check_cmd_line_flag("var_split_kv");
+    if (page == -1) {
+      is_var_split_kv = false;
+    }
@@ -149,6 +153,10 @@ struct Options {
     if (is_var_split_kv == true) {
       split_kv = max_split_kv;
     }
+    is_fused_reduction = cmd.check_cmd_line_flag("fuse_reduction");
+    if (split_kv == 1) {
+      is_fused_reduction = false;
+    }
     cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
     verify = cmd.check_cmd_line_flag("verify");
     verbose = cmd.check_cmd_line_flag("verbose");
@@ -176,6 +184,8 @@ struct Options {
       << "  --iterations=<int>  Benchmarking iterations\n"
       << "  --spread=<float>    Relative spread away from K for paging\n"
       << "  --split_kv=<int>    Split KV factor\n"
+      << "  --fused_reduction   Fuse the reduction operation\n"
+      << "  --var_split_kv      Use varying split KV factor\n"
       << "  --verify            Verify results\n"
       << "  --verbose           Print smem and execution time per kernel\n"
       << "  --sm-count          Sets SM count rather than querying it\n"
@@ -514,7 +524,8 @@ struct Runner {
         stride_LSE},
       hw_info,
       options.split_kv,
-      options.is_var_split_kv ? block_split_kv.get() : nullptr
+      options.is_var_split_kv ? block_split_kv.get() : nullptr,
+      options.is_fused_reduction
     };
     if (options.split_kv < 0 && !options.is_var_split_kv) {
       Operation::set_split_kv(arguments);
@@ -724,13 +735,17 @@ void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info
   // Persistent Tile Scheduler
   run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
   // Individual Tile Scheduler
-  run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
+  if (!options.is_fused_reduction || options.split_kv == 1) {
+    run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
+  }
 #elif FP16
   name += " fp16";
   // Persistent Tile Scheduler
   run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
   // Individual Tile Scheduler
-  run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
+  if (!options.is_fused_reduction || options.split_kv == 1) {
+    run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
+  }
 #endif
 }
 
@@ -90,6 +90,7 @@ struct Options {
   bool verbose = false;
 
   bool causal = false;
+  bool causal_q_begin = true;
   bool residual = false;
   bool varlen = false;
   bool persistent = false;
@@ -231,6 +232,8 @@ struct Options {
 
     std::string mask;
     cmd.get_cmd_line_argument<std::string>("mask", mask, "");
+    std::string causal_type;
+    cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
     if (mask == "no" || mask == "") {
       causal = residual = false;
       if (varlen) {
@@ -240,6 +243,11 @@ struct Options {
     else if (mask == "causal") {
      residual = false;
      causal = true;
+      if (causal_type == "qend") {
+        causal_q_begin = false;
+      } else {
+        causal_q_begin = true;
+      }
     }
     else if (mask == "residual") {
       residual = true;
@@ -279,6 +287,7 @@ struct Options {
       << "  --verify                     Verify results\n"
       << "  --verbose                    Print smem and execution time per kernel\n"
       << "  --mask=<no|residual|causal>  Enables masking\n"
+      << "  --causal-type=<qbegin|qend>  Causal mask type\n"
       << "  --persistent                 Enables persistent scheduler\n"
       << "  --varlen                     Enables variable sequence length\n"
       << "                               B*Q and B*K become the total sequence length\n"
@@ -1013,7 +1022,11 @@ int main_single(int argc, char const **args) {
 
   auto with_mask = [&](auto fn) {
     if (options.causal) {
-      fn(CausalMask<false>{});
+      if (options.causal_q_begin) {
+        fn(CausalMask{});
+      } else {
+        fn(CausalMask<false>{});
+      }
     }
     else if (options.residual) {
       fn(ResidualMask{});
 
@@ -59,6 +59,14 @@ set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2
 set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
 set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
 set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
+set(TEST_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
+set(TEST_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
+set(TEST_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
+set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
 
 set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
 set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
@@ -75,6 +83,13 @@ set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --d
 set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
 set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
 set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
+set(TEST_MLA_FWD_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_MLA_FWD_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
+set(TEST_MLA_FWD_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
+set(TEST_MLA_FWD_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_MLA_FWD_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
+set(TEST_MLA_FWD_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
 
 set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
 set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
@@ -87,6 +102,9 @@ set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
 set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
 set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
 
+set(TEST_MLA_SEP_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --verify)
+set(TEST_MLA_FUSE_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify)
 
 if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
 
   foreach(PREC fp8 fp16)
@@ -116,6 +134,12 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
       TEST_VARLEN_12
       TEST_VARLEN_13
       TEST_VARLEN_14
+      TEST_VARLEN_15
+      TEST_VARLEN_16
+      TEST_VARLEN_17
+      TEST_VARLEN_18
+      TEST_VARLEN_19
+      TEST_VARLEN_20
     )
     target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
     target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
@@ -139,6 +163,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
       77_blackwell_mla.cu
       TEST_COMMAND_OPTIONS
       TEST_MLA_BASIC
+      TEST_MLA_SEP_REDUCTION
+      TEST_MLA_FUSE_REDUCTION
     )
     target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
     target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
@@ -149,6 +175,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
       77_blackwell_mla.cu
       TEST_COMMAND_OPTIONS
       TEST_MLA_BASIC
+      TEST_MLA_SEP_REDUCTION
+      TEST_MLA_FUSE_REDUCTION
     )
     target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
     target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
@@ -207,6 +235,12 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
       TEST_MLA_FWD_VARLEN_12
       TEST_MLA_FWD_VARLEN_13
       TEST_MLA_FWD_VARLEN_14
+      TEST_MLA_FWD_VARLEN_15
+      TEST_MLA_FWD_VARLEN_16
+      TEST_MLA_FWD_VARLEN_17
+      TEST_MLA_FWD_VARLEN_18
+      TEST_MLA_FWD_VARLEN_19
+      TEST_MLA_FWD_VARLEN_20
     )
     target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
     target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})
@@ -203,13 +203,12 @@ struct CausalMask : NoMask {
 
     // See note below on different ways to think about causal attention
     // Again, we'd add the offset_q into the max_blocks_q calculation
-    int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
+    if constexpr (IsQBegin) {
+      int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
+      int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
+      return std::min(max_blocks_k, max_blocks_q);
+    } else {
+      const int offset_q = get<1>(problem_size) - get<0>(problem_size);
+      int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
+      int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
+      return std::min(max_blocks_k, max_blocks_q);
+    }
@@ -222,12 +221,12 @@ struct CausalMask : NoMask {
       TileShape const& tile_shape,
       ProblemSize const& problem_size) {
 
-    int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
+    if constexpr (IsQBegin) {
+      int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
+      return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
+    } else {
+      const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
-      return ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape));
+      return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
+    }
   }
@@ -277,9 +276,10 @@ struct CausalMask : NoMask {
   }
 };
 
-struct CausalForBackwardMask : CausalMask<true>, ResidualMaskForBackward {
+template<bool kIsQBegin = true>
+struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
 
-  using Base = CausalMask<true>;
+  using Base = CausalMask<kIsQBegin>;
 
   template<class AccQK, class IndexQK, class ProblemSize>
   CUTLASS_DEVICE
@@ -296,10 +296,15 @@ struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
     // where we only compute the next row and use cache for the rest
     // - if you'd like this, you only need to add an offset like so:
     //   get<0>(pos) + offset_q < get<1>(pos)
+    int offset_q = 0;
+    if constexpr (!kIsQBegin) {
+      offset_q = get<1>(problem_size) - get<0>(problem_size);
+    }
 
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < size(acc_qk); i++) {
       auto pos = index_qk(i);
-      bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size);
+      bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);
       if (masked) {
         acc_qk(i) = -INFINITY;
       }
 
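A minimal standalone sketch of the two causal alignments this hunk implements (hypothetical helper names, not part of the diff):

// qbegin: the diagonal is anchored at the top-left, so query row q sees keys k <= q.
// qend:   the diagonal is anchored at the bottom-right; with
//         offset_q = seqlen_k - seqlen_q, query row q sees keys k <= q + offset_q.
bool masked_qbegin(int q, int k) { return q < k; }
bool masked_qend(int q, int k, int seqlen_q, int seqlen_k) {
  int offset_q = seqlen_k - seqlen_q;  // e.g. seqlen_q=17, seqlen_k=25 in the varlen tests gives offset_q = 8
  return q + offset_q < k;
}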
@@ -100,7 +100,7 @@ public:
     cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
   >;
 
-  using OperationNormal= cutlass::fmha::device::FMHA<
+  using OperationMha= cutlass::fmha::device::FMHA<
     cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
       ProblemShape, Element, ElementAccumulator, TileShape, Mask
     >
@@ -112,7 +112,7 @@ public:
     >
   >;
 
-  using Operation = std::conditional_t<IsMla, OperationMla, OperationNormal>;
+  using Operation = std::conditional_t<IsMla, OperationMla, OperationMha>;
 
   using Kernel = typename Operation::Kernel;
 
@@ -127,7 +127,11 @@ public:
     int waves = ceil_div(B * split_heur, sm_count);
     int k_waves = ceil_div(max_splits, split_heur);
     int split_wave_aware = ceil_div(max_splits, k_waves);
-    args.split_kv = split_wave_aware;
+    if (args.is_fused_reduction && split_wave_aware > 1) {
+      args.split_kv = std::min(split_wave_aware, static_cast<int>(sm_count/2));
+    } else {
+      args.split_kv = split_wave_aware;
+    }
   }
 
   /// Determines whether the GEMM can execute the given problem.
@@ -273,11 +277,33 @@ public:
     CUTLASS_TRACE_HOST("MLA::run()");
     dim3 const block = Kernel::get_block_shape();
     dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
+    auto [H, K, D, B] = params.fmha_params.problem_shape;
+    auto [D_latent, D_rope] = D;
 
     // configure smem size and carveout
     int smem_size = Kernel::SharedStorageSize;
 
     Status launch_result;
+    if (params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
+      auto result = cudaMemsetAsync(params.fmha_params.epilogue.ptr_o, 0, sizeof(typename Kernel::ElementOut) * H * D_latent * B, stream);
+      if (cudaSuccess != result) {
+        result = cudaGetLastError(); // to clear the error bit
+        CUTLASS_TRACE_HOST(
+          "  cudaMemsetAsync() returned error: "
+          << cudaGetErrorString(result));
+        return Status::kErrorInternal;
+      }
+      auto total_bytes = H * B * (sizeof(int) + sizeof(typename Kernel::ElementLSE)) + 2 * B * sizeof(int);
+      uint8_t* ws = reinterpret_cast<uint8_t*>(params.fmha_params.epilogue.ptr_lse_exchange_buff);
+      result = cudaMemsetAsync(ws, 0, total_bytes, stream);
+      if (cudaSuccess != result) {
+        result = cudaGetLastError(); // to clear the error bit
+        CUTLASS_TRACE_HOST(
+          "  cudaMemsetAsync() returned error: "
+          << cudaGetErrorString(result));
+        return Status::kErrorInternal;;
+      }
+    }
     // Use extended launch API only for mainloops that use it
     if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
       dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
@@ -298,7 +324,7 @@ public:
       CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
       return Status::kErrorInternal;
     }
-    if (params.reduction_params.split_kv > 1) {
+    if (!params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
       // launch reduction kernel
       dim3 const block = ReductionKernel::get_block_shape();
       dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
 
@@ -1245,9 +1245,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     };
 
     bool leading_causal_masking = false;
-    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
-        || std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
+    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
       leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
+    } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
+      int offset = get<1>(problem_shape) - get<0>(problem_shape);
+      int kv_left = get<1>(blk_coord) * TileShapeK{};
+      int kv_right = kv_left + TileShapeK{} - 1;
+      int q_left = iter_index * TileShapeQ{} + offset;
+      int q_right = q_left + TileShapeQ{} - 1;
+
+      leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
     }
     bool trailing_residual_masking = false;
     if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
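A worked instance of the qend interval test above (illustrative numbers):

// TileShapeQ = TileShapeK = 128, seqlen_q = 256, seqlen_k = 512 => offset = 256.
// The K block at index 1 covers keys 128..255; Q tile 0 covers the diagonal band
// of keys 256..383 (q_left = 0 * 128 + 256). Since q_left (256) > kv_right (255),
// the band misses this tile: every key in it is visible to every query of the
// tile, so no per-element causal masking is needed there.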
@@ -1683,9 +1690,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     );
     int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
     int iter_start = 0;
-    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask> ||
-        std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
+    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
       iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
+    } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
+      int offset = get<1>(problem_shape) - get<0>(problem_shape);
+      iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
     }
     if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
       return;
 
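And a matching worked instance for the qend iteration start (illustrative numbers):

// TileShapeQ = TileShapeK = 128, seqlen_q = 256, seqlen_k = 512 => offset = 256.
// For the K block at index 3 (keys 384..511):
//   iter_start = max(0, (3 * 128 - 256) / 128) = 1,
// i.e. Q tile 0 (queries 0..127, visible keys <= 383) never reaches this K block,
// so the dK/dV accumulation starts at Q tile 1.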
@@ -1230,9 +1230,16 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
     };
 
     bool leading_causal_masking = false;
-    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
-        || std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
+    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
       leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
+    } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
+      int offset = get<1>(problem_shape) - get<0>(problem_shape);
+      int kv_left = get<1>(blk_coord) * TileShapeK{};
+      int kv_right = kv_left + TileShapeK{} - 1;
+      int q_left = iter_index * TileShapeQ{} + offset;
+      int q_right = q_left + TileShapeQ{} - 1;
+
+      leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
     }
     bool trailing_residual_masking = false;
     if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
@@ -1677,9 +1684,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
     );
     int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
     int iter_start = 0;
-    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
-        || std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
+    if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
       iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
+    } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
+      int offset = get<1>(problem_shape) - get<0>(problem_shape);
+      iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
     }
     if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
       return;
 
@@ -101,8 +101,17 @@ struct Sm100FmhaMlaReductionKernel {
 
   CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
     if (params.split_kv <= 1) return;
 
     auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
 
+    auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
+    auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
+    auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
+    auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
+    local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
+
+    if (local_split_kv == 1) return;
+
     __shared__ ElementAcc sLseScale[kMaxSplits];
     const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
     const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
@ -130,17 +133,18 @@ struct Sm100FmhaMlaReductionKernel {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
lse_max = fmax(local_lse[i], lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
lse_max = fmax(__shfl_xor_sync(0xffffffff, lse_max, offset), lse_max);
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
|
||||
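The max-shifted reduction above is the standard log-sum-exp merge across splits. A host-side sketch of the same math (a simplified, single-threaded illustration, not the kernel code):

#include <cmath>

// Merge per-split LSE values: lse = log(sum_i exp(lse_i)), computed stably
// by shifting with the running maximum before exponentiating.
float merge_lse(const float* local_lse, int num_splits) {
  float lse_max = -INFINITY;
  for (int i = 0; i < num_splits; ++i) lse_max = fmaxf(lse_max, local_lse[i]);
  if (lse_max == -INFINITY) lse_max = 0.0f;  // all splits empty; mirrors the kernel's guard
  float sum = 0.0f;
  for (int i = 0; i < num_splits; ++i) sum += expf(local_lse[i] - lse_max);
  return logf(sum) + lse_max;
}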
@@ -36,6 +36,7 @@
 
 #include "cute/tensor.hpp"
 #include "cute/arch/simd_sm100.hpp"
+#include "cutlass/barrier.h"
 
 #include "cutlass/arch/arch.h"
 #include "cutlass/arch/memory_sm80.h"
@@ -44,6 +45,7 @@
 
 #include "gather_tensor.hpp" // from examples/common
+#include "common/pow_2.hpp"
 #include "sm100_mla_tile_scheduler.hpp"
 
 namespace cutlass::fmha::kernel {
 
@@ -87,8 +89,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
   using TileShapeR = tuple_element_t<1, TileShapeD>;
   static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim");
 
-  using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;
-  using TensorStride = Stride<int64_t, _1, int64_t>;
+  using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;
+  using TensorStride = Stride<int64_t, _1, int64_t>;
   using TmemAllocator = cute::conditional_t<kIs2Sm, cute::TMEM::Allocator2Sm, cute::TMEM::Allocator1Sm>;
 
   static_assert(TileShapeH{} == 128);
@@ -181,10 +183,13 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
   using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB;
   using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB;
   using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int<IterationsPV_K>{}, _2{})));
   using SmemLayoutOut = decltype(take<0,2>(typename CollectiveMmaQK::CtaShape_MNK{}));
+  using TileShapeAcc = decltype(take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}));
 
   static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
   static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v<Element>);
   static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v<Element>);
 
   // pre-condition for overlapped smem staging
   static_assert(kBytesLoadKC == kBytesLoadVC);
   static_assert(StagesQK == StagesPV);
@@ -226,7 +231,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
       alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKC>> smem_kc;
       alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutVC>> smem_vc;
     };
-    alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
+    union {
+      alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
+      alignas(2048) cute::array<ElementOut, size(TileShapeAcc{})> smem_acc;
+    };
   };
 
   struct SharedStorage {
@@ -280,6 +288,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     KernelHardwareInfo hw_info;
     int split_kv = -1;
     int* ptr_split_kv = nullptr;
+    bool is_fused_reduction = false;
   };
 
   using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A;
@@ -288,6 +297,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
   using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B;
   using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B;
 
+  using GmemLayout = decltype(make_layout(Shape<int,int,int>{}, Stride<int64_t, _1, int64_t>{}));
+  using SmemLayout = decltype(make_layout(TileShapeAcc{}, LayoutRight{}));
+
+  using TmaReduceSum = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
+      make_tensor(recast_ptr<ElementOut>(nullptr), GmemLayout{}), SmemLayout{}));
+
   struct MainloopParams {
     TmaLoadQLatent tma_load_q_latent;
     TmaLoadQRope tma_load_q_rope;
@@ -306,6 +321,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     Stride<_1, int> stride_lse;
     Stride<_1, int> stride_lse_acc;
     ElementAcc output_scale = 1.0f;
+    ElementLSE* ptr_lse_exchange_buff = nullptr;
+    int* ptr_lse_max_exchange_buff = nullptr;
+    int* ptr_lock = nullptr; // semaphore
+    TmaReduceSum tma_reduce_sum;
   };
 
   struct Params {
@@ -316,6 +335,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     typename TileScheduler::Params tile_scheduler;
     int split_kv = -1;
     int* ptr_split_kv = nullptr;
+    bool is_fused_reduction = false;
   };
 
   static Params to_underlying_arguments(Arguments const& args, void* workspace) {
@@ -380,11 +400,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
 
     epilogue_params.ptr_o = args.epilogue.ptr_o;
     epilogue_params.stride_o = args.epilogue.stride_o;
     epilogue_params.ptr_lse = args.epilogue.ptr_lse;
     epilogue_params.stride_lse = args.epilogue.stride_lse;
     epilogue_params.output_scale = args.epilogue.output_scale;
+    epilogue_params.tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(recast_ptr<ElementOut>(args.epilogue.ptr_o), make_layout(make_shape(H, L, B), args.epilogue.stride_o)), SmemLayout{});
 
-    if (args.split_kv > 1) {
+    if (!args.is_fused_reduction && args.split_kv > 1) {
       ElementAcc* ptr_o_acc = reinterpret_cast<ElementAcc*>(workspace);
       ElementLSE* ptr_lse_acc = reinterpret_cast<ElementLSE*>(ptr_o_acc + H * L * args.split_kv * B);
       epilogue_params.ptr_o_acc = ptr_o_acc;
@@ -392,10 +413,18 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
 
       epilogue_params.stride_o_acc = make_tuple(static_cast<int64_t>(0 + L) * args.split_kv, _1{}, static_cast<int64_t>(0 + H * L) * args.split_kv);
       epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv);
+    } else if (args.is_fused_reduction && args.split_kv > 1) {
+      ElementLSE* ptr_lse_exchange_buff = reinterpret_cast<ElementLSE*>(workspace);
+      epilogue_params.ptr_lse_exchange_buff = ptr_lse_exchange_buff;
+      int* ptr_lse_max_exchange_buff = reinterpret_cast<int*>(ptr_lse_exchange_buff + H * B);
+      epilogue_params.ptr_lse_max_exchange_buff = ptr_lse_max_exchange_buff;
+      int* ptr_lock = ptr_lse_max_exchange_buff + H * B;
+      epilogue_params.ptr_lock = ptr_lock;
+    }
 
     return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params,
-        TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv};
+        TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv),
+        args.split_kv, args.ptr_split_kv, args.is_fused_reduction};
   }
 
   static size_t get_workspace_size(Arguments const& args) {
@@ -403,10 +432,29 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     auto [H, K, D, B] = problem_shape;
     auto [D_latent, D_rope] = D;
     auto split_kv = args.split_kv;
-    return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
+    size_t workspace_size {0};
+    if (args.is_fused_reduction && args.split_kv > 1) {
+      // one exchange buffer for LSE max and another buffer for total LSE
+      // two locks per batch, first lock is for CTA0 / H=0..63 and the second is for CTA1 / H=64..127
+      workspace_size = H * B * (sizeof(int) + sizeof(ElementLSE)) + 2 * B * sizeof(int);
+    } else if (!args.is_fused_reduction && args.split_kv > 1) {
+      workspace_size = (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
+    }
+    return workspace_size;
   }
   static Status initialize_workspace(
-      Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
+      Arguments const& args, void* ws, cudaStream_t stream) {
+    auto workspace_size = get_workspace_size(args);
+    if (args.is_fused_reduction && args.split_kv > 1) {
+      auto result = cudaMemsetAsync(ws, 0, workspace_size);
+      if (cudaSuccess != result) {
+        result = cudaGetLastError(); // to clear the error bit
+        CUTLASS_TRACE_HOST(
+          "  cudaMemsetAsync() returned error: "
+          << cudaGetErrorString(result));
+        return Status::kErrorInternal;;
+      }
+    }
     return Status::kSuccess;
   }
 
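A worked size comparison for the two workspace layouts (illustrative values; ElementAcc and ElementLSE assumed to be float):

// H = 128 heads, B = 1, D_latent = 512, split_kv = 8:
//   fused:     H*B*(sizeof(int) + sizeof(float)) + 2*B*sizeof(int)
//              = 1024 + 8 = 1032 bytes (exchange buffers plus per-batch locks);
//   non-fused: (sizeof(float) * D_latent + sizeof(float)) * H * split_kv * B
//              = (4 * 512 + 4) * 128 * 8 = 2,101,248 bytes of O/LSE accumulators.
// Fusing the reduction removes the large split-KV accumulator entirely.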
@@ -448,6 +496,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
       std::cerr << __FILE__ << "(" << __LINE__ << "): split-k off\n";
       return false;
     }
+    if (args.is_fused_reduction && args.split_kv > 1) {
+      if (2 * args.split_kv > args.hw_info.sm_count ||
+          std::is_same_v<TileScheduler, Sm100MlaIndividualTileScheduler>) {
+        return false;
+      }
+    }
     return true;
   }
 
@@ -746,7 +800,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         pipeline_mma_s, pipeline_mma_s_consumer_state,
         pipeline_p_mma, pipeline_p_mma_producer_state,
         pipeline_mma_o, pipeline_mma_o_consumer_state,
-        local_split_kv
+        local_split_kv,
+        params.is_fused_reduction
       );
     }
 
@@ -1777,7 +1832,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
 
     auto [H, K, D, B] = problem_shape;
     auto [D_latent, D_rope] = D;
-    if (epilogue_args.ptr_o_acc != nullptr) {
+
+    if (split_kv > 1) {
       using ElementOutAcc = ElementAcc;
       constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v<ElementOutAcc>;
       Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc);
@@ -1806,16 +1862,20 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
 
       copy(tTR_rO_src, tR2G_rO_dst);
 
-      // compute LSE
-      ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
-
-      // store LSE
-      Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc);
-      Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
-      // for 2x2 dp, this must be conditional and the index is wrong
-      if (! kIs2Sm || (threadIdx.x < 64))
-      {
-        gLSE(threadIdx.x) = lse;
+      if (get<1>(cta_coord) == 0) {
+        if (epilogue_args.ptr_lse != nullptr) {
+          // compute LSE
+          ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
+
+          // store LSE
+          Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc);
+          Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
+          // for 2x2 dp, this must be conditional and the index is wrong
+          if (! kIs2Sm || (threadIdx.x < 64))
+          {
+            gLSE(threadIdx.x) = lse;
+          }
+        }
       }
     }
     else {
@@ -1845,24 +1905,165 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
 
       copy(tTR_rO_src, tR2G_rO_dst);
 
-      if (epilogue_args.ptr_lse != nullptr) {
-        // compute LSE
-        ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
-
-        // store LSE
-        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
-        Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
-
-        // for 2x2 dp, this must be conditional and the index is wrong
-        if (! kIs2Sm || (threadIdx.x < 64))
-        {
-          gLSE(threadIdx.x) = lse;
-        }
-      }
+      if (get<1>(cta_coord) == 0) {
+        if (epilogue_args.ptr_lse != nullptr) {
+          // compute LSE
+          ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
+
+          // store LSE
+          Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
+          Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
+
+          // for 2x2 dp, this must be conditional and the index is wrong
+          if (! kIs2Sm || (threadIdx.x < 64))
+          {
+            gLSE(threadIdx.x) = lse;
+          }
+        }
+      }
     }
   }
+
+  template<class BlkCoord>
+  CUTLASS_DEVICE ElementLSE epilogue_lse_reduction(
+      ElementAcc& row_max,
+      ElementAcc& row_sum,
+      BlkCoord const& cta_coord,
+      ProblemShape const& problem_shape,
+      MainloopArguments const& mainloop_args,
+      EpilogueParams const& epilogue_args,
+      int const& local_split_kv) {
+
+    auto [H, K, D, B] = problem_shape;
+    auto [D_latent, D_rope] = D;
+    auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{});
+
+    constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp;
+    using Sync = cutlass::detail::NamedBarrierSync<kNumThreads, kNamedBarrierExchange>;
+
+    auto wait = [](int* lock, int count) {
+      __threadfence();
+      if (threadIdx.x == 0) {
+        atomicAdd(lock, 1);
+        while (atomicCAS(lock, count, count) != count) {};
+      }
+      __threadfence();
+      Sync::sync();
+    };
+
+    const ElementLSE lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
+    Tensor mLSE_max_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_max_exchange_buff), make_shape(H, B), epilogue_args.stride_lse);
+    Tensor gLSE_max_buff = local_tile(mLSE_max_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
+
+    int* local_lock = epilogue_args.ptr_lock + get<0>(cta_coord) + 2 * get<2>(cta_coord);
+
+    if (! kIs2Sm || (threadIdx.x < 64)) {
+      atomicMax(&(gLSE_max_buff(threadIdx.x)), __float2int_rn(lse));
+    }
+    wait(local_lock, local_split_kv);
+
+    auto global_lse_max = static_cast<ElementLSE>(gLSE_max_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x));
+
+    Tensor mLSE_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_exchange_buff), make_shape(H, B), epilogue_args.stride_lse);
+    Tensor gLSE_buff = local_tile(mLSE_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
+
+    if (! kIs2Sm || (threadIdx.x < 64)) {
+      atomicAdd(&(gLSE_buff(threadIdx.x)), expf(lse - global_lse_max));
+    }
+    wait(local_lock, 2*local_split_kv);
+
+    const auto sum_lse = gLSE_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x);
+    const auto global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementLSE>::infinity() :
+        cutlass::fast_log(sum_lse) + global_lse_max;
+    const auto lse_scale = expf(lse - global_lse);
+
+    if (epilogue_args.ptr_lse != nullptr) {
+      Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
+      Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
+
+      // write out the global LSE
+      if (! kIs2Sm || (threadIdx.x < 64)) {
+        gLSE(threadIdx.x) = global_lse;
+      }
+    }
+    return lse_scale;
+  }
+
+  template<class BlkCoord>
+  CUTLASS_DEVICE void epilogue_reduction(
+      ElementAcc& row_max,
+      ElementAcc& row_sum,
+      BlkCoord const& blk_coord,
+      ProblemShape const& problem_shape,
+      MainloopArguments const& mainloop_args,
+      EpilogueParams const& epilogue_args,
+      TensorStorage& shared_tensors,
+      int const& local_split_kv,
+      ElementLSE const& lse_scale) {
+
+    constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp;
+    using Sync = cutlass::detail::NamedBarrierSync<kNumThreads, kNamedBarrierExchange>;
+
+    auto [H, K, D, B] = problem_shape;
+    auto [D_latent, D_rope] = D;
+
+    auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
+
+    TiledMmaPV tiled_mma_pv;
+    Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{})));
+
+    CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{});
+    CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{});
+
+    using EpilogueLinearCombination = cutlass::epilogue::thread::LinearCombination<ElementOut, 1, ElementAcc, ElementAcc, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+    EpilogueLinearCombination epilogue_op({epilogue_args.output_scale / row_sum * lse_scale});
+
+    CUTLASS_PRAGMA_UNROLL
+    for(int k = 0; k < IterationsPV_N; ++k) {
+      auto cta_coord = replace<1>(blk_coord, k);
+
+      uint32_t tmem_o = uint32_t(TmemAllocation::kO0) + k * uint32_t(TmemAllocation::kSizeAccO);
+      tOtO.data() = tmem_o;
+
+      Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{});
+
+      Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o);
+      Tensor gO = local_tile(mO, TileShapeAcc{}, take<0,3>(cta_coord));
+
+      auto tiled_t2r = make_tmem_copy(load_op, tAcc);
+      auto thread_idx = threadIdx.x % size(tiled_t2r);
+
+      auto thread_t2r = tiled_t2r.get_slice(thread_idx);
+      Tensor tTR_gO = thread_t2r.partition_D(gO);
+      Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));
+      Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);
+
+      copy(tiled_t2r, tTR_tAcc, tTR_rAcc);
+
+      Tensor sO = make_tensor(make_smem_ptr(reinterpret_cast<ElementOut*>(shared_tensors.smem_acc.begin())), SmemLayout{});
+      Tensor tTR_sO = thread_t2r.partition_D(sO);
+
+      Sync::sync();
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < size(tTR_rAcc); i++) {
+        tTR_sO(i) = epilogue_op(tTR_rAcc(i));
+      }
+      tma_store_fence();
+      Sync::sync();
+
+      auto tma_reduce_sum_per_cta = epilogue_args.tma_reduce_sum.get_slice(_0{});
+      auto gmem_tensor_coord = epilogue_args.tma_reduce_sum.get_tma_tensor(shape(mO));
+      auto gmem_tensor_coord_per_cta = local_tile(gmem_tensor_coord, TileShapeAcc{}, take<0,3>(cta_coord));
+      if (threadIdx.x % kNumThreads == 0) {
+        copy(epilogue_args.tma_reduce_sum,
+             tma_reduce_sum_per_cta.partition_S(sO),
+             tma_reduce_sum_per_cta.partition_D(gmem_tensor_coord_per_cta));
+        tma_store_arrive();
+      }
+      tma_store_wait<0>();
+    }
+  }
 
   template<class CtaCoord>
   CUTLASS_DEVICE void compute(
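The `wait` lambda above is a global-memory rendezvous across the CTAs of one split group. A minimal standalone sketch of the pattern (assumes a zero-initialized lock and co-resident CTAs; illustrative only, the kernel itself synchronizes the CTA through a named barrier rather than __syncthreads):

__device__ void rendezvous(int* lock, int expected) {
  __threadfence();                 // publish this CTA's prior global-memory writes
  if (threadIdx.x == 0) {
    atomicAdd(lock, 1);            // announce arrival
    while (atomicCAS(lock, expected, expected) != expected) {}  // spin until all CTAs arrived
  }
  __threadfence();                 // order the reads that follow
  __syncthreads();                 // release the remaining threads of the CTA
}

Because every CTA in the group spins on the same lock, all of them must be resident on the GPU at once; this is consistent with can_implement above rejecting fused-reduction configurations where 2 * split_kv exceeds the SM count.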
@@ -1877,7 +2078,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
       typename PipelineP::PipelineState& pipeline_p_mma_producer_state,
       PipelineO& pipeline_mma_o,
       typename PipelineO::PipelineState& pipeline_mma_o_consumer_state,
-      int const& split_kv) {
+      int const& split_kv,
+      bool const& is_fused_reduction) {
 
     auto [H, K, D, B] = problem_shape;
 
@@ -1987,17 +2189,38 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
 
       cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();
 
-      // epilogue
-      CUTLASS_PRAGMA_UNROLL
-      for (int j = 0; j < IterationsPV_N; j++) {
-        epilogue(
-          row_max, row_sum,
-          replace<1>(cta_coord, j), problem_shape,
-          mainloop_args, epilogue_args, shared_tensors,
-          uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv
-        );
-      }
+      const int actual_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
+      if (!is_fused_reduction || actual_split_kv == 1) {
+        // epilogue
+        CUTLASS_PRAGMA_UNROLL
+        for (int j = 0; j < IterationsPV_N; j++) {
+          epilogue(
+            row_max, row_sum,
+            replace<1>(cta_coord, j), problem_shape,
+            mainloop_args, epilogue_args, shared_tensors,
+            uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO),
+            actual_split_kv
+          );
+        }
+      } else {
+        const ElementLSE lse_scale =
+          epilogue_lse_reduction(
+            row_max, row_sum,
+            cta_coord,
+            problem_shape,
+            mainloop_args, epilogue_args,
+            actual_split_kv
+          );
+
+        epilogue_reduction(row_max, row_sum,
+          cta_coord,
+          problem_shape,
+          mainloop_args, epilogue_args,
+          shared_tensors,
+          actual_split_kv,
+          lse_scale
+        );
+      }
 
       cutlass::arch::fence_view_async_tmem_load();
       pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
       ++pipeline_mma_o_consumer_state;
 
@@ -142,8 +142,8 @@ struct Sm100MlaPersistentTileScheduler {
     int block_decode = block_idx;
     int m_block, bidb, n_split_kv;
     params.divmod_m_block(block_decode, m_block, block_decode);
-    params.divmod_b(block_decode, bidb, block_decode);
     params.divmod_split_kv(block_decode, n_split_kv, block_decode);
+    params.divmod_b(block_decode, bidb, block_decode);
     return make_coord(m_block, _0{}, bidb, n_split_kv);
   }