From 314fa8abbf9d4f6dc89eba1d8fdf80e6cd4432ed Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 16 Oct 2025 09:36:09 -0400 Subject: [PATCH] [Attention] Tune CUTLASS MLA num_splits (#26846) Signed-off-by: Matthew Bonanni --- .../cutlass_sm100_mla/device/sm100_mla.hpp | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp index 297d94dcc0..2d4b4a67d2 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -125,32 +125,37 @@ public: } static void set_split_kv (KernelArguments& args) { - // printf("set_split_kv start"); if (args.split_kv >= 1) return; auto [H, K, D, B] = args.problem_shape; - // std::cout << H << " " << K << " " << D << " " << B << "\n"; int sm_count = args.hw_info.sm_count; - // printf(" sm_count = %d\n", sm_count); - int max_splits = ceil_div(K, 128); - max_splits = min(16, max_splits); + float seq_length_k = static_cast(K) / 1024.0f; + int max_splits = 1; - // TODO: This avoids a hang when the batch size larger than 1 and - // there is more than 1 kv_splits. - // Discuss with NVIDIA how this can be fixed. - if (B > 1) { - max_splits = min(1, max_splits); + if (B <= 4 && seq_length_k >= 16) { + max_splits = 16; } - - // printf(" max_splits = %d\n", max_splits); + else if (B <= 8 && seq_length_k >= 4) { + max_splits = 8; + } + else if ((B <= 16 && seq_length_k >= 8) || + (B == 48 && seq_length_k >= 32)) { + max_splits = 4; + } + else if ((B <= 32 && seq_length_k >= 16) || + (B == 96 && seq_length_k >= 16)) { + max_splits = 2; + } + else { + max_splits = 1; + } + + // Wave-aware scheduling: ensure integer number of waves in K dimension int sms_per_batch = max(1, sm_count / B); - // printf(" sms_per_batch = %d\n", sms_per_batch); int split_heur = min(max_splits, sms_per_batch); int waves = ceil_div(B * split_heur, sm_count); int k_waves = ceil_div(max_splits, split_heur); int split_wave_aware = ceil_div(max_splits, k_waves); args.split_kv = split_wave_aware; - // printf(" args.split_kv = %d\n", args.split_kv); - } /// Determines whether the GEMM can execute the given problem.