Migrate FlashMLA codes to example. (#2135)

This commit is contained in:
Junkai-Wu
2025-02-26 14:29:07 +08:00
committed by GitHub
parent af5519d938
commit 15f5468872
19 changed files with 627 additions and 878 deletions

View File

@ -1,542 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A Hopper CUTLASS example for Flash MLA.
*/
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <random>
#include <vector>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include <thrust/universal_vector.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include "cutlass/util/command_line.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <cuda_runtime.h>
#include "flash_fwd_mla_kernel.h"
#include "flash_mla.h"
#include "fill_nan.h"
#include "transform.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
if (error != cudaSuccess) { \
std::cerr << "CUDA error: " << cudaGetErrorString(error) << " at " << \
__FILE__ << ":" << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
int iterations;
int b, s, h_q, s_q;
int h_kv, d, dv;
float softmax_scale;
bool varlen;
bool causal;
static constexpr int block_size = 64;
Options():
help(false),
b(128), s(4096), h_q(16), s_q(1),
h_kv(1), d(576), dv(512),
varlen(false),
causal(true),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
Options defaults;
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("b", b, defaults.b);
cmd.get_cmd_line_argument("s", s, defaults.s);
cmd.get_cmd_line_argument("h_q", h_q, defaults.h_q);
cmd.get_cmd_line_argument("s_q", s_q, defaults.s_q);
cmd.get_cmd_line_argument("h_kv", h_kv, defaults.h_kv);
cmd.get_cmd_line_argument("d", d, defaults.d);
cmd.get_cmd_line_argument("dv", dv, defaults.dv);
if (cmd.check_cmd_line_flag("varlen")) {
varlen = true;
}
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
softmax_scale = 1 / std::sqrt(d);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79_hopper_flash_mla\n\n"
<< " Hopper Flash MLA kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the batch size\n"
<< " --s=<int> Sets the sequence length\n"
<< " --h_q=<int> Sets the number of heads\n"
<< " --s_q=<int> Sets the sequence length of the query\n"
<< " --varlen Sets the varlen as true or false\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
return out;
}
/// TODO: Compute performance in GFLOP/s
};
/// Helper to initialize a block of device data
template <typename Element>
static void
initialize_values(
thrust::universal_vector<Element>& dst_ptr,
cutlass::Distribution::Kind dist_kind,
uint64_t seed,
Element var = Element(1.f)) {
if (cutlass::Distribution::Uniform == dist_kind) {
int scope = 2;
cutlass::reference::host::BlockFillRandomUniform(
dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0);
}
else if (cutlass::Distribution::AllZeros == dist_kind) {
cutlass::reference::host::BlockFillRandomUniform(
dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0);
}
else if (cutlass::Distribution::AllOnes == dist_kind) {
cutlass::reference::host::BlockFillRandomUniform(
dst_ptr.data().get(), dst_ptr.size(), seed, 1, 1, 0);
}
else if (cutlass::Distribution::Gaussian == dist_kind) {
cutlass::reference::device::BlockFillRandomGaussian(
dst_ptr.data().get(), dst_ptr.size(), seed, (Element) 0, var);
}
else if (cutlass::Distribution::Sequential == dist_kind) {
cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size());
}
else {
std::cerr << "Invalid distribution kind!\n.";
exit(1);
}
}
void initialize_varlen(thrust::universal_vector<int32_t>& block_C, const Options &options) {
block_C.resize(options.b);
std::vector<int32_t> cache_seqlens(options.b, options.s);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> distribution(options.s, options.s / 2.0f);
for (int i = 0; i < options.b; ++i) {
if (options.varlen) {
float random_length = distribution(gen);
cache_seqlens[i] = std::max(static_cast<int32_t>(random_length), options.s_q);
} else {
cache_seqlens[i] = options.s;
}
}
CUDA_CHECK(cudaMemcpy(
block_C.data().get(),
cache_seqlens.data(),
options.b * sizeof(int32_t),
cudaMemcpyHostToDevice
));
}
void initialize_metadata(
thrust::universal_vector<int32_t> &block_C,
thrust::universal_vector<int32_t> &block_MD, thrust::universal_vector<int32_t> &block_S,
int& num_sm_parts,
const Options &options) {
// This should match the logic in the MLA kernel.
static constexpr int block_size_m = 64;
static constexpr int block_size_n = 64;
static constexpr int fixed_overhead_num_blocks = 5;
cudaDeviceProp props;
int current_device;
CUDA_CHECK(cudaGetDevice(&current_device));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
auto batch_size = options.b;
int sm_count = props.multiProcessorCount;
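// Number of SM partitions to split the KV-block work across (mirrors FlashMLA's get_mla_metadata).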
num_sm_parts = sm_count / options.h_kv / cutlass::ceil_div(options.s_q * options.h_q / options.h_kv, block_size_m);
block_MD.resize(num_sm_parts * TileSchedulerMetaDataSize);
block_S.resize(options.b + 1);
Mla_metadata_params params{};
params.seqlens_k_ptr = block_C.data().get();
params.tile_scheduler_metadata_ptr = block_MD.data().get();
params.num_splits_ptr = block_S.data().get();
params.batch_size = batch_size;
params.block_size_n = block_size_n;
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
params.num_sm_parts = num_sm_parts;
cudaStream_t stream{nullptr};
get_mla_metadata_func(params, stream);
}
// only transpose the dimensions 2 and 3
template <class Element>
void transpose(
thrust::universal_vector<Element> &block_S,
thrust::universal_vector<Element> &block_D,
cute::tuple<int, int, int, int, int> problem_shape) {
using Operator = cutlass::transform::device::TransformUniversalAdapter<TransposeKernel<Element>>;
cudaError_t result;
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the Transpose kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
}
typename Operator::Arguments arguments{
block_S.data().get(),
block_D.data().get(),
problem_shape,
};
Operator op;
size_t workspace_size = Operator::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return;
}
status = op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return;
}
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return;
}
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return;
}
}
struct TestBed {
using Element = cutlass::bfloat16_t;
using ElementAcc = float;
thrust::universal_vector<Element> block_Q; // query
thrust::universal_vector<Element> block_Q_T; // query transpose
thrust::universal_vector<Element> block_K; // blocked key
thrust::universal_vector<int32_t> block_T; // block table
thrust::universal_vector<int32_t> block_C; // cache seqlens
// TODO: block_V is not used in the example
// thrust::universal_vector<Element> block_V; // dv
thrust::universal_vector<int32_t> block_MD; // mla metadata
thrust::universal_vector<int32_t> block_S; // num splits
thrust::universal_vector<Element> block_O; // output
thrust::universal_vector<Element> block_LSE; // lse
thrust::universal_vector<Element> block_O_T; // output transpose
thrust::universal_vector<Element> block_LSE_T; // lse transpose
thrust::universal_vector<ElementAcc> block_O_Accum; // output
thrust::universal_vector<ElementAcc> block_LSE_Accum; // lse
/// Initialize operands to be used in the Flash MLA kernel
void initialize(
const Options &options,
int& total_blocks, int& blocks_per_seq, int& num_sm_parts,
uint64_t seed = 2025) {
initialize_varlen(block_C, options);
thrust::device_ptr<int32_t> d_ptr(block_C.data().get());
int64_t total_seqlens = thrust::reduce(d_ptr, d_ptr + options.b);
float sum = static_cast<float>(total_seqlens);
int32_t mean_seqlens = static_cast<int32_t>(sum / options.b);
int32_t max_seqlen = thrust::reduce(d_ptr, d_ptr + options.b,
0,
thrust::maximum<int32_t>());
int max_seqlen_pad = ((max_seqlen + 255) / 256) * 256;
blocks_per_seq = max_seqlen_pad / options.block_size;
total_blocks = options.b * blocks_per_seq;
// Query: [b, s_q, h_q, d]
block_Q.resize(options.b * options.s_q * options.h_q * options.d);
block_Q_T.resize(options.b * options.s_q * options.h_q * options.d);
// Block table: [b, max_num_blocks_per_seq]
block_T.resize(total_blocks);
// Key: [b, max_num_blocks_per_seq, block_size, h_kv, d]
block_K.resize(total_blocks * options.block_size * options.h_kv * options.d);
initialize_values(block_Q, cutlass::Distribution::Gaussian, seed + 1);
initialize_values(block_T, cutlass::Distribution::Sequential, seed + 3);
initialize_values(block_K, cutlass::Distribution::Gaussian, seed + 5);
// Set the exceeding part to NaN
fill_nan(block_K.data().get(), block_C.data().get(),
options.b, max_seqlen_pad, options.h_kv, options.d);
initialize_metadata(block_C, block_MD, block_S, num_sm_parts, options);
int ngroups = options.h_q / options.h_kv;
int num_heads = options.h_kv;
int seqlen_q = options.s_q * ngroups;
// LSE: [batch_size, num_heads, seqlen_q]
block_LSE.resize(options.b * num_heads * seqlen_q);
block_LSE_T.resize(options.b * seqlen_q * num_heads);
// Output: [batch_size, seqlen_q, num_heads, head_size_v]
block_O.resize(options.b * seqlen_q * num_heads * options.dv);
block_O_T.resize(options.b * seqlen_q * num_heads * options.dv);
auto softmax_lse_size = (options.b + num_sm_parts) * num_heads * seqlen_q;
auto out_accum_size = (options.b + num_sm_parts) * num_heads * seqlen_q * options.dv;
block_LSE_Accum.resize(softmax_lse_size);
block_O_Accum.resize(out_accum_size);
}
/// Execute a given example Flash MLA computation
void run(Options &options)
{
cudaDeviceProp props;
int current_device;
CUDA_CHECK(cudaGetDevice(&current_device));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
// TODO: use vcache which is None in the example
auto batch_size = options.b;
auto seqlen_q_ori = options.s_q;
auto num_heads_ori = options.h_q;
auto head_size = options.d;
auto head_size_v = options.dv;
auto num_heads_k = options.h_kv;
auto page_block_size = options.block_size;
int total_blocks, max_num_blocks_per_seq;
int num_sm_parts;
assert(head_size % 8 == 0);
assert(head_size_v % 32 == 0);
initialize(options, total_blocks, max_num_blocks_per_seq, num_sm_parts);
assert(batch_size > 0);
assert(num_heads_ori % num_heads_k == 0);
bool is_causal = seqlen_q_ori == 1 ? false : options.causal;
int ngroups = num_heads_ori / num_heads_k;
int seqlen_q = seqlen_q_ori * ngroups;
int num_heads = num_heads_k;
// Preprocess the query: fold the query-head groups into the sequence dimension (see TransposeKernel in transform.h)
transpose(
block_Q, block_Q_T,
cute::make_shape(options.b, seqlen_q_ori, num_heads_k, ngroups, options.d));
cudaStream_t stream{nullptr};
// set the parameters
Flash_fwd_mla_params kernel_params{};
kernel_params.b = options.b;
kernel_params.seqlen_q = options.s_q;
kernel_params.d = options.d;
kernel_params.d_v = options.dv;
kernel_params.h = options.h_q;
kernel_params.h_h_k_ratio = num_heads_ori / num_heads_k;
kernel_params.ngroups = ngroups;
kernel_params.q_ptr = block_Q_T.data().get();
kernel_params.k_ptr = block_K.data().get();
// TODO: block_V is not used in the example
kernel_params.v_ptr = block_K.data().get();
kernel_params.o_ptr = block_O.data().get();
kernel_params.softmax_lse_ptr = block_LSE.data().get();
kernel_params.q_batch_stride = seqlen_q * num_heads * options.d;
kernel_params.k_batch_stride = page_block_size * options.h_kv * options.d;
kernel_params.v_batch_stride = page_block_size * options.h_kv * options.dv;
kernel_params.o_batch_stride = options.s_q * options.h_q * options.dv;
kernel_params.q_row_stride = num_heads * options.d;
kernel_params.k_row_stride = options.h_kv * options.d;
kernel_params.v_row_stride = options.h_kv * options.dv;
kernel_params.o_row_stride = options.h_q * options.dv;
kernel_params.q_head_stride = options.d;
kernel_params.k_head_stride = options.d;
kernel_params.v_head_stride = options.dv;
kernel_params.o_head_stride = options.dv;
kernel_params.block_table = block_T.data().get();
kernel_params.block_table_batch_stride = max_num_blocks_per_seq;
kernel_params.page_block_size = page_block_size;
kernel_params.tile_scheduler_metadata_ptr = block_MD.data().get();
kernel_params.num_splits_ptr = block_S.data().get();
kernel_params.softmax_lseaccum_ptr = block_LSE_Accum.data().get();
kernel_params.oaccum_ptr = block_O_Accum.data().get();
kernel_params.is_causal = is_causal;
kernel_params.scale_softmax = options.softmax_scale;
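// The kernel evaluates softmax with exp2, so the scale is pre-multiplied by log2(e).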
kernel_params.scale_softmax_log2 = float(options.softmax_scale * M_LOG2E);
kernel_params.cu_seqlens_k = block_C.data().get();
kernel_params.num_sm_parts = num_sm_parts;
assert(head_size == 576);
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(kernel_params, stream);
CUDA_CHECK(cudaDeviceSynchronize());
transpose(
block_O, block_O_T,
cute::make_shape(options.b, seqlen_q_ori, ngroups, num_heads_k, options.dv));
transpose(
block_LSE, block_LSE_T,
cute::make_shape(options.b, num_heads_k, seqlen_q_ori, ngroups, 1));
// TODO: reference check
printf("run done\n");
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.7 Toolkit (or newer) to run this example
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 7)) {
std::cerr << "This example requires CUDA 12.7 or newer." << std::endl;
// Returning zero so this example passes on older toolkits as a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 9.0 (SM90)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
TestBed testbed{};
testbed.run(options);
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,36 +0,0 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# The kernel in this example triggers an ICE in gcc 7.5
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0))
cutlass_example_add_executable(
68_hopper_flash_mla
68_hopper_flash_mla.cu
)
endif()

View File

@ -0,0 +1,12 @@
# Hopper FlashMLA - Examples
The code in this example is migrated from [FlashMLA](https://github.com/deepseek-ai/FlashMLA/tree/main); it implements an efficient MLA (Multi-head Latent Attention) decoding kernel for Hopper (SM90) GPUs.
# Run the example
### Install
```
python setup.py install
```
### Run the test
```
python tests/test_flash_mla.py
```
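### Example usage
A minimal usage sketch (tensor sizes are illustrative; the call signatures follow `flash_mla/flash_mla_interface.py`, and a Hopper GPU with CUDA is assumed):
```
import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

# Illustrative sizes: batch 2, 1 query token, 16 query heads, 1 KV head,
# QK head dim 576, V head dim 512, paged KV cache with block size 64.
b, s_q, h_q, h_kv, d, dv, block_size = 2, 1, 16, 1, 576, 512, 64
cache_seqlens = torch.full((b,), 256, dtype=torch.int32, device="cuda")
blocks_per_seq = 256 // block_size
q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
blocked_k = torch.randn(b * blocks_per_seq, block_size, h_kv, d,
                        dtype=torch.bfloat16, device="cuda")
block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32,
                           device="cuda").view(b, blocks_per_seq)

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, s_q * h_q // h_kv, h_kv)
out, softmax_lse = flash_mla_with_kvcache(
    q, blocked_k, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True)
```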

View File

@ -0,0 +1,213 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/fast_math.h>
#include "flash_mla.h"
#include "static_switch.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
std::vector<at::Tensor>
get_mla_metadata(
at::Tensor &seqlens_k,
const int num_heads_per_head_k,
const int num_heads_k
) {
// This should match the logic in the MLA kernel.
static constexpr int block_size_m = 64;
static constexpr int block_size_n = 64;
static constexpr int fixed_overhead_num_blocks = 5;
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
int batch_size = seqlens_k.size(0);
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
auto options = seqlens_k.options();
auto dprops = at::cuda::getCurrentDeviceProperties();
int sm_count = dprops->multiProcessorCount;
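// e.g. on a GPU with 132 SMs, num_heads_k = 1 and num_heads_per_head_k = 16: 132 / 1 / ceil_div(16, 64) = 132 SM parts.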
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
auto num_splits = torch::empty({batch_size + 1}, options);
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
int *num_splits_ptr = num_splits.data_ptr<int>();
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
Mla_metadata_params params = {};
params.seqlens_k_ptr = seqlens_k_ptr;
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
params.num_splits_ptr = num_splits_ptr;
params.batch_size = batch_size;
params.block_size_n = block_size_n;
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
params.num_sm_parts = num_sm_parts;
get_mla_metadata_func(params, stream);
return {tile_scheduler_metadata, num_splits};
}
std::vector<at::Tensor>
mha_fwd_kvcache_mla(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
auto q_dtype = q.dtype();
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_ori = sizes[2];
const int head_size = sizes[3];
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int ngroups = num_heads_ori / num_heads_k;
const int seqlen_q = seqlen_q_ori * ngroups;
const int num_heads = num_heads_k;
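// Fold the ngroups query heads per KV head into the sequence dimension, so the kernel runs with num_heads_k heads and seqlen_q = seqlen_q_ori * ngroups.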
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
.reshape({batch_size, seqlen_q, num_heads, head_size});
int head_size_k = head_size;
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
Flash_fwd_mla_params params = {};
// Set the sizes.
params.b = batch_size;
params.seqlen_q = seqlen_q;
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
params.h = num_heads;
params.h_h_k_ratio = num_heads / num_heads_k;
params.ngroups = ngroups;
params.is_causal = is_causal;
params.d = head_size;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.v_ptr = vcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = softmax_lse.data_ptr();
// All strides are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.v_batch_stride = vcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(-3);
params.v_row_stride = vcache.stride(-3);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(-2);
params.v_head_stride = vcache.stride(-2);
params.o_head_stride = out.stride(-2);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_CONTIGUOUS(tile_scheduler_metadata);
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
params.num_sm_parts = tile_scheduler_metadata.size(0);
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
CHECK_DEVICE(num_splits);
CHECK_CONTIGUOUS(num_splits);
params.num_splits_ptr = num_splits.data_ptr<int>();
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size == 576);
if (q_dtype == torch::kBFloat16) {
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
}
#ifndef FLASH_MLA_DISABLE_FP16
else if (q_dtype == torch::kHalf) {
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
}
#endif
else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
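// Unfold the groups from the sequence dimension so the outputs come back in the original (batch, seqlen_q, num_heads_q, ...) layout.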
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
return {out, softmax_lse};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA";
m.def("get_mla_metadata", &get_mla_metadata);
m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
}

View File

@ -0,0 +1,3 @@
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);

View File

@ -0,0 +1,3 @@
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);

View File

@ -1,5 +1,3 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/flash_fwd_mla_kernel.h
#pragma once
#include <cute/tensor.hpp>
@ -578,7 +576,6 @@ template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
constexpr size_t smem_size = sizeof(SharedStorage);
@ -586,6 +583,7 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();
dim3 grid_combine(params.b * params.h * params.seqlen_q);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
@ -603,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream)
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
}
static constexpr int MaxBatchSize = 4096;
__global__ void __launch_bounds__(256, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
int batch_size = params.batch_size;
int block_size_n = params.block_size_n;
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
__shared__ int num_blocks_shared[MaxBatchSize];
__shared__ int num_splits_shared[MaxBatchSize];
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
num_blocks_shared[i] = num_blocks;
}
for (int offset = 16; offset >= 1; offset /= 2) {
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
}
__syncwarp();
if (threadIdx.x == 0) {
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) {
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
tile_scheduler_metadata0[0] = now_idx;
tile_scheduler_metadata0[1] = now_block * block_size_n;
tile_scheduler_metadata1 = now_n_split_idx;
int remain_payload = payload;
while (now_idx < batch_size) {
int num_blocks = num_blocks_shared[now_idx];
int now_remain_blocks = num_blocks - now_block;
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
cum_num_splits += now_n_split_idx + 1;
num_splits_shared[now_idx + 1] = cum_num_splits;
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
++now_idx;
now_block = 0;
now_n_split_idx = 0;
} else {
if (remain_payload - fixed_overhead_num_blocks > 0) {
now_block += remain_payload - fixed_overhead_num_blocks;
++now_n_split_idx;
remain_payload = 0;
}
break;
}
}
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
}
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}
__syncwarp();
for (int i = threadIdx.x; i <= batch_size; i += 32) {
num_splits_ptr[i] = num_splits_shared[i];
}
}
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.batch_size < MaxBatchSize);
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}

View File

@ -0,0 +1,77 @@
#include "flash_fwd_mla_kernel.h"
static constexpr int MaxBatchSize = 4096;
__global__ void __launch_bounds__(256, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
int batch_size = params.batch_size;
int block_size_n = params.block_size_n;
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
__shared__ int num_blocks_shared[MaxBatchSize];
__shared__ int num_splits_shared[MaxBatchSize];
int total_num_blocks = 0;
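// Launched as a single warp: each thread counts the KV blocks for a strided subset of the batch.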
for (int i = threadIdx.x; i < batch_size; i += 32) {
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
num_blocks_shared[i] = num_blocks;
}
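// Butterfly shuffle reduction: warp-wide total of blocks, including the fixed per-sequence overhead.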
for (int offset = 16; offset >= 1; offset /= 2) {
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
}
__syncwarp();
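// Thread 0 greedily walks the batch, handing each SM part roughly `payload` blocks of KV work; a long sequence may be split across several parts (tracked by now_n_split_idx), and num_splits_shared records the cumulative split count per batch entry.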
if (threadIdx.x == 0) {
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) {
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
tile_scheduler_metadata0[0] = now_idx;
tile_scheduler_metadata0[1] = now_block * block_size_n;
tile_scheduler_metadata1 = now_n_split_idx;
int remain_payload = payload;
while (now_idx < batch_size) {
int num_blocks = num_blocks_shared[now_idx];
int now_remain_blocks = num_blocks - now_block;
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
cum_num_splits += now_n_split_idx + 1;
num_splits_shared[now_idx + 1] = cum_num_splits;
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
++now_idx;
now_block = 0;
now_n_split_idx = 0;
} else {
if (remain_payload - fixed_overhead_num_blocks > 0) {
now_block += remain_payload - fixed_overhead_num_blocks;
++now_n_split_idx;
remain_payload = 0;
}
break;
}
}
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
}
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}
__syncwarp();
for (int i = threadIdx.x; i <= batch_size; i += 32) {
num_splits_ptr[i] = num_splits_shared[i];
}
}
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.batch_size < MaxBatchSize);
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}

View File

@ -1,5 +1,3 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/flash_mla.h
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -62,4 +60,4 @@ struct Mla_metadata_params {
int num_sm_parts;
};
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);

View File

@ -1,5 +1,3 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/named_barrier.h
#pragma once
#include "cutlass/barrier.h"
@ -14,4 +12,4 @@ enum class NamedBarriers {
SoftmaxReady = 2,
};
} // flash
} // flash

View File

@ -194,4 +194,4 @@ struct Softmax {
};
};
} // namespace flash
} // namespace flash

View File

@ -1,5 +1,3 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/static_switch.h
#pragma once
#define CHECK_CUDA(call) \
@ -64,4 +62,4 @@
} else { \
FLASH_ASSERT(false); \
} \
}()
}()

View File

@ -235,4 +235,4 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layou
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
} // namespace flash

View File

@ -1,79 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Process the input&output data in Flash MLA kernels
*/
#pragma once
#include <limits>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
template <typename T>
struct fill_nan_functor {
T* blocked_k;
const int* cache_seqlens;
int max_seqlen_pad, h_kv, d;
__host__ __device__
fill_nan_functor(T* bk, const int* cs, int msp, int hkv, int dd) :
blocked_k(bk), cache_seqlens(cs), max_seqlen_pad(msp), h_kv(hkv), d(dd) {}
__host__ __device__
void operator()(int idx) {
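// idx enumerates the flattened [b, max_seqlen_pad, h_kv, d] K cache; entries beyond the true sequence length are poisoned with NaN so out-of-range reads would show up in the output.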
auto NAN_VALUE = std::numeric_limits<T>::quiet_NaN();
int h_d_size = h_kv * d;
int seq_h_d_size = max_seqlen_pad * h_d_size;
int batch_idx = idx / seq_h_d_size;
int pos_in_seq = (idx % seq_h_d_size) / h_d_size;
if (pos_in_seq >= cache_seqlens[batch_idx]) {
blocked_k[idx] = NAN_VALUE; // NaN
}
}
};
template <typename T>
void fill_nan(
T* d_blocked_k,
const int* d_cache_seqlens,
int b, int max_seqlen_pad, int h_kv, int d
) {
int total_elements = b * max_seqlen_pad * h_kv * d;
thrust::for_each(
thrust::device,
thrust::counting_iterator<int>(0),
thrust::counting_iterator<int>(total_elements),
fill_nan_functor(d_blocked_k, d_cache_seqlens, max_seqlen_pad, h_kv, d)
);
}

View File

@ -0,0 +1,6 @@
__version__ = "1.0.0"
from flash_mla.flash_mla_interface import (
get_mla_metadata,
flash_mla_with_kvcache,
)

View File

@ -0,0 +1,67 @@
from typing import Optional, Tuple
import torch
import flash_mla_cuda
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: The number of KV heads.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
)
return out, softmax_lse

View File

@ -0,0 +1,87 @@
import os
from pathlib import Path
from datetime import datetime
from setuptools import setup, find_packages
from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
IS_WINDOWS,
)
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "32"
return nvcc_extra_args + ["--threads", nvcc_threads]
def get_sources():
sources = [
"csrc/flash_api.cpp",
"csrc/flash_fwd_mla_bf16_sm90.cu",
"csrc/flash_fwd_mla_metadata.cu",
]
if not DISABLE_FP16:
sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
return sources
def get_features_args():
features_args = []
if DISABLE_FP16:
features_args.append("-DFLASH_MLA_DISABLE_FP16")
return features_args
cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90a,code=sm_90a")
this_dir = os.path.dirname(os.path.abspath(__file__))
if IS_WINDOWS:
cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
else:
cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
ext_modules = []
ext_modules.append(
CUDAExtension(
name="flash_mla_cuda",
sources=get_sources(),
extra_compile_args={
"cxx": cxx_args + get_features_args(),
"nvcc": append_nvcc_threads(
[
"-O3",
"-std=c++17",
"-DNDEBUG",
"-D_USE_MATH_DEFINES",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v,--register-usage-level=10"
]
+ cc_flag
) + get_features_args(),
},
include_dirs=[
Path(this_dir) / "csrc",
Path(this_dir) / ".." / ".." / "include",
],
)
)
setup(
name="flash_mla",
version="1.0.0",
packages=find_packages(include=['flash_mla']),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
)

View File

@ -0,0 +1,153 @@
import argparse
import math
import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
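# The s_q query tokens are aligned with the last s_q cached positions, hence the diagonal offset of s_k - s_q.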
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < 1e-5
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
float("nan")
)
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv
)
def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
)
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(q.dtype).bits // 8
)
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype):
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16"],
default="bf16",
help="Data type to use for testing (bf16 or fp16)",
)
args = parser.parse_args()
torch_dtype = torch.bfloat16
if args.dtype == "fp16":
torch_dtype = torch.float16
main(torch_dtype)

View File

@ -1,131 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/arch.h"
using namespace cutlass;
template <class Element_>
struct TransposeKernel {
using Element = Element_;
// TODO: use more threads to copy tensor
static constexpr int MaxThreadsPerBlock = 1;
static constexpr int MinBlocksPerMultiprocessor = 1;
using ArchTag = cutlass::arch::Sm90;
static constexpr int AlignmentBytes = 16;
struct SharedStorage {
/* empty */
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size
using ProblemShape = cute::tuple<int, int, int, int, int>;
struct Arguments {
const Element *ptrS{nullptr};
Element *ptrD{nullptr};
ProblemShape problem_shape{};
};
struct Params {
const Element *ptrS{nullptr};
Element *ptrD{nullptr};
ProblemShape problem_shape{};
};
static Params
to_underlying_arguments(Arguments const&args, void* workspace) {
return Params{
args.ptrS,
args.ptrD,
args.problem_shape
};
}
static Status
can_implement(Arguments const& args) {
return Status::kSuccess;
}
static size_t
get_workspace_size(Arguments const& args) {
return size_t(0);
}
static Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
return Status::kSuccess;
}
static dim3
get_grid_shape(Params const& params) {
auto [B, S, H, G, D] = params.problem_shape;
return dim3(B*S, H, G);
}
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void
operator()(Params params, [[maybe_unused]] char* smem_buf = nullptr) {
auto [B, S, H, G, D] = params.problem_shape;
// default is column-major layout
auto src_layout_ = make_layout(make_shape(D,G,H,S,B));
auto src_layout = make_layout(reverse(src_layout_.shape()), reverse(src_layout_.stride()));
auto dst_layout_ = make_layout(make_shape(D,H,G,S,B));
auto dst_layout = make_layout(reverse(dst_layout_.shape()), reverse(dst_layout_.stride()));
auto src_tensor = make_tensor(
make_gmem_ptr(params.ptrS),
group<0,2>(src_layout)
);
auto dst_tensor = make_tensor(
make_gmem_ptr(params.ptrD),
group<0,2>(dst_layout)
);
auto tS = src_tensor(blockIdx.x, blockIdx.y, blockIdx.z, _);
auto tD = dst_tensor(blockIdx.x, blockIdx.y, blockIdx.z, _);
copy(tS, tD);
}
};