SM100 Cutlass MLA decode with unrestricted num_heads (< 128) for DeepSeek TP (#20769)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
committed by
GitHub
parent
61e20828da
commit
8cdc371217
@ -514,6 +514,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, Tensor workspace, float "
|
||||
"scale,"
|
||||
" int num_kv_splits) -> ()");
|
||||
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA workspace
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
||||
" int sm_count, int num_kv_splits) "
|
||||
"-> int");
|
||||
ops.impl("sm100_cutlass_mla_get_workspace_size",
|
||||
&sm100_cutlass_mla_get_workspace_size);
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
|
||||
Reference in New Issue
Block a user