SM100 Cutlass MLA decode with unrestricted num_heads (< 128) for DeepSeek TP (#20769)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev
2025-07-14 21:06:38 -04:00
committed by GitHub
parent 61e20828da
commit 8cdc371217
12 changed files with 3283 additions and 2 deletions

View File

@ -514,6 +514,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// SM100 CUTLASS MLA decode
ops.def(
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, Tensor workspace, float "
"scale,"
" int num_kv_splits) -> ()");
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
// SM100 CUTLASS MLA workspace
ops.def(
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
" int sm_count, int num_kv_splits) "
"-> int");
ops.impl("sm100_cutlass_mla_get_workspace_size",
&sm100_cutlass_mla_get_workspace_size);
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input,"