Migrate FlashMLA codes to example. (#2135)
This commit is contained in:
@ -1,542 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A Hopper CUTLASS example for Flash MLA.
|
||||
*/
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
|
||||
#include <thrust/universal_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
#include "flash_mla.h"
|
||||
#include "fill_nan.h"
|
||||
#include "transform.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUDA_CHECK(status) \
|
||||
{ \
|
||||
cudaError_t error = status; \
|
||||
if (error != cudaSuccess) { \
|
||||
std::cerr << "CUDA error: " << cudaGetErrorString(error) << " at " << \
|
||||
__FILE__ << ":" << __LINE__ << std::endl; \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
int iterations;
|
||||
int b, s, h_q, s_q;
|
||||
int h_kv, d, dv;
|
||||
float softmax_scale;
|
||||
bool varlen;
|
||||
bool causal;
|
||||
|
||||
static constexpr int block_size = 64;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
b(128), s(4096), h_q(16), s_q(1),
|
||||
h_kv(1), d(576), dv(512),
|
||||
varlen(false),
|
||||
causal(true),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, defaults.b);
|
||||
cmd.get_cmd_line_argument("s", s, defaults.s);
|
||||
cmd.get_cmd_line_argument("h_q", h_q, defaults.h_q);
|
||||
cmd.get_cmd_line_argument("s_q", s_q, defaults.s_q);
|
||||
cmd.get_cmd_line_argument("h_kv", h_kv, defaults.h_kv);
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("dv", dv, defaults.dv);
|
||||
|
||||
if (cmd.check_cmd_line_flag("varlen")) {
|
||||
varlen = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
|
||||
softmax_scale = 1 / std::sqrt(d);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "79_hopper_flash_mla\n\n"
|
||||
<< " Hopper Flash MLA kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the batch size\n"
|
||||
<< " --s=<int> Sets the sequence length\n"
|
||||
<< " --h_q=<int> Sets the number of heads\n"
|
||||
<< " --s_q=<int> Sets the sequence length of the query\n"
|
||||
<< " --varlen Sets the varlen as true or false\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// TOOD:Compute performance in GFLOP
|
||||
|
||||
};
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element>
|
||||
static void
|
||||
initialize_values(
|
||||
thrust::universal_vector<Element>& dst_ptr,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed,
|
||||
Element var = Element(1.f)) {
|
||||
if (cutlass::Distribution::Uniform == dist_kind) {
|
||||
int scope = 2;
|
||||
cutlass::reference::host::BlockFillRandomUniform(
|
||||
dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0);
|
||||
}
|
||||
else if (cutlass::Distribution::AllZeros == dist_kind) {
|
||||
cutlass::reference::host::BlockFillRandomUniform(
|
||||
dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0);
|
||||
}
|
||||
else if (cutlass::Distribution::AllOnes == dist_kind) {
|
||||
cutlass::reference::host::BlockFillRandomUniform(
|
||||
dst_ptr.data().get(), dst_ptr.size(), seed, 1, 1, 0);
|
||||
}
|
||||
else if (cutlass::Distribution::Gaussian == dist_kind) {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
dst_ptr.data().get(), dst_ptr.size(), seed, (Element) 0, var);
|
||||
}
|
||||
else if (cutlass::Distribution::Sequential == dist_kind) {
|
||||
cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size());
|
||||
}
|
||||
else {
|
||||
std::cerr << "Invalid distribution kind!\n.";
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
void initialize_varlen(thrust::universal_vector<int32_t>& block_C, const Options &options) {
|
||||
|
||||
block_C.resize(options.b);
|
||||
|
||||
std::vector<int32_t> cache_seqlens(options.b, options.s);
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
|
||||
std::normal_distribution<float> distribution(options.s, options.s / 2.0f);
|
||||
|
||||
for (int i = 0; i < options.b; ++i) {
|
||||
if (options.varlen) {
|
||||
float random_length = distribution(gen);
|
||||
cache_seqlens[i] = std::max(static_cast<int32_t>(random_length), options.s_q);
|
||||
} else {
|
||||
cache_seqlens[i] = options.s;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::DeviceAllocation<int32_t> d_cache_seqlens(options.b);
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
block_C.data().get(),
|
||||
cache_seqlens.data(),
|
||||
options.b * sizeof(int32_t),
|
||||
cudaMemcpyHostToDevice
|
||||
));
|
||||
}
|
||||
|
||||
auto initialize_metadata(
|
||||
thrust::universal_vector<int32_t> &block_C,
|
||||
thrust::universal_vector<int32_t> &block_MD, thrust::universal_vector<int32_t> &block_S,
|
||||
int& num_sm_parts,
|
||||
const Options &options) {
|
||||
|
||||
// This should match the logic in the MLA kernel.
|
||||
static constexpr int block_size_m = 64;
|
||||
static constexpr int block_size_n = 64;
|
||||
static constexpr int fixed_overhead_num_blocks = 5;
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
|
||||
|
||||
auto batch_size = options.b;
|
||||
int sm_count = props.multiProcessorCount;
|
||||
|
||||
num_sm_parts = sm_count / options.h_kv / cutlass::ceil_div(options.h_kv, block_size_m);
|
||||
|
||||
block_MD.resize(num_sm_parts * TileSchedulerMetaDataSize);
|
||||
block_S.resize(options.b + 1);
|
||||
|
||||
Mla_metadata_params params{};
|
||||
params.seqlens_k_ptr = block_C.data().get();
|
||||
params.tile_scheduler_metadata_ptr = block_MD.data().get();
|
||||
params.num_splits_ptr = block_S.data().get();
|
||||
params.batch_size = batch_size;
|
||||
params.block_size_n = block_size_n;
|
||||
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
|
||||
params.num_sm_parts = num_sm_parts;
|
||||
|
||||
cudaStream_t stream{nullptr};
|
||||
|
||||
get_mla_metadata_func(params, stream);
|
||||
}
|
||||
|
||||
// only transpose the dimensions 2 and 3
|
||||
template <class Element>
|
||||
void transpose(
|
||||
thrust::universal_vector<Element> &block_S,
|
||||
thrust::universal_vector<Element> &block_D,
|
||||
cute::tuple<int, int, int, int, int> problem_shape) {
|
||||
|
||||
using Operator = cutlass::transform::device::TransformUniversalAdapter<TransposeKernel<Element>>;
|
||||
|
||||
cudaError_t result;
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the Transpose kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
}
|
||||
|
||||
typename Operator::Arguments arguments{
|
||||
block_S.data().get(),
|
||||
block_D.data().get(),
|
||||
problem_shape,
|
||||
};
|
||||
|
||||
Operator op;
|
||||
|
||||
size_t workspace_size = Operator::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
struct TestBed {
|
||||
using Element = cutlass::bfloat16_t;
|
||||
using ElementAcc = float;
|
||||
|
||||
thrust::universal_vector<Element> block_Q; // query
|
||||
thrust::universal_vector<Element> block_Q_T; // query transpose
|
||||
thrust::universal_vector<Element> block_K; // blocked key
|
||||
thrust::universal_vector<int32_t> block_T; // block table
|
||||
thrust::universal_vector<int32_t> block_C; // cache seqlens
|
||||
// TODO: block_V is not used in the example
|
||||
// thrust::universal_vector<Element> block_V; // dv
|
||||
thrust::universal_vector<int32_t> block_MD; // mla metadata
|
||||
thrust::universal_vector<int32_t> block_S; // num splits
|
||||
thrust::universal_vector<Element> block_O; // output
|
||||
thrust::universal_vector<Element> block_LSE; // lse
|
||||
thrust::universal_vector<Element> block_O_T; // output transpose
|
||||
thrust::universal_vector<Element> block_LSE_T; // lse transpose
|
||||
thrust::universal_vector<ElementAcc> block_O_Accum; // output
|
||||
thrust::universal_vector<ElementAcc> block_LSE_Accum; // lse
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(
|
||||
const Options &options,
|
||||
int& total_blocks, int& blocks_per_seq, int& num_sm_parts,
|
||||
uint64_t seed = 2025) {
|
||||
|
||||
initialize_varlen(block_C, options);
|
||||
|
||||
thrust::device_ptr<int32_t> d_ptr(block_C.data().get());
|
||||
|
||||
int64_t total_seqlens = thrust::reduce(d_ptr, d_ptr + options.b);
|
||||
float sum = static_cast<float>(total_seqlens);
|
||||
int32_t mean_seqlens = static_cast<int32_t>(sum / options.b);
|
||||
int32_t max_seqlen = thrust::reduce(d_ptr, d_ptr + options.b,
|
||||
0,
|
||||
thrust::maximum<int32_t>());
|
||||
int max_seqlen_pad = ((max_seqlen + 255) / 256) * 256;
|
||||
|
||||
blocks_per_seq = max_seqlen_pad / options.block_size;
|
||||
total_blocks = options.b * blocks_per_seq;
|
||||
|
||||
// Query: [b, s_q, h_q, d]
|
||||
block_Q.resize(options.b * options.s_q * options.h_q * options.d);
|
||||
block_Q_T.resize(options.b * options.s_q * options.h_q * options.d);
|
||||
|
||||
// Block table: [b, max_num_blocks_per_seq]
|
||||
block_T.resize(total_blocks);
|
||||
|
||||
// Key: [b, max_num_blocks_per_seq, block_size, h_kv, d]
|
||||
block_K.resize(total_blocks * options.block_size * options.h_kv * options.d);
|
||||
|
||||
initialize_values(block_Q, cutlass::Distribution::Gaussian, seed + 1);
|
||||
initialize_values(block_T, cutlass::Distribution::Sequential, seed + 3);
|
||||
initialize_values(block_K, cutlass::Distribution::Gaussian, seed + 5);
|
||||
|
||||
// Set the exceeding part to NaN
|
||||
fill_nan(block_K.data().get(), block_C.data().get(),
|
||||
options.b, max_seqlen_pad, options.h_kv, options.d);
|
||||
|
||||
initialize_metadata(block_C, block_MD, block_S, num_sm_parts, options);
|
||||
|
||||
int ngroups = options.h_q / options.h_kv;
|
||||
int num_heads = options.h_kv;
|
||||
int seqlen_q = options.s_q * ngroups;
|
||||
|
||||
// LSE: [batch_size, num_heads, seqlen_q]
|
||||
block_LSE.resize(options.b * num_heads * seqlen_q);
|
||||
block_LSE_T.resize(options.b * seqlen_q * num_heads);
|
||||
// Output: [batch_size, seqlen_q, num_heads, head_size_v]
|
||||
block_O.resize(options.b * seqlen_q * num_heads * options.dv);
|
||||
block_O_T.resize(options.b * seqlen_q * num_heads * options.dv);
|
||||
|
||||
auto softmax_lse_size = (options.b + num_sm_parts) * num_heads * seqlen_q;
|
||||
auto out_accum_size = (options.b + num_sm_parts) * num_heads * seqlen_q * options.dv;
|
||||
|
||||
block_LSE_Accum.resize(softmax_lse_size);
|
||||
block_O_Accum.resize(out_accum_size);
|
||||
}
|
||||
|
||||
/// Execute a given example Flash MLA computation
|
||||
void run(Options &options)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
|
||||
|
||||
// TODO: use vcache which is None in the example
|
||||
|
||||
auto batch_size = options.b;
|
||||
auto seqlen_q_ori = options.s_q;
|
||||
auto num_heads_ori = options.h_q;
|
||||
auto head_size = options.d;
|
||||
auto head_size_v = options.dv;
|
||||
auto num_heads_k = options.h_kv;
|
||||
auto page_block_size = options.block_size;
|
||||
int total_blocks, max_num_blocks_per_seq;
|
||||
int num_sm_parts;
|
||||
|
||||
assert(head_size % 8 == 0);
|
||||
assert(head_size_v % 32 == 0);
|
||||
|
||||
initialize(options, total_blocks, max_num_blocks_per_seq, num_sm_parts);
|
||||
|
||||
assert(batch_size > 0);
|
||||
assert(num_heads_ori % num_heads_k == 0);
|
||||
|
||||
bool is_causal = seqlen_q_ori == 1 ? false : options.causal;
|
||||
|
||||
int ngroups = num_heads_ori / num_heads_k;
|
||||
int seqlen_q = seqlen_q_ori * ngroups;
|
||||
int num_heads = num_heads_k;
|
||||
|
||||
// preprocess the query
|
||||
transpose(
|
||||
block_Q, block_Q_T,
|
||||
cute::make_shape(options.b, seqlen_q_ori, num_heads_k, ngroups, options.d));
|
||||
|
||||
cudaStream_t stream{nullptr};
|
||||
|
||||
// set the parameters
|
||||
Flash_fwd_mla_params kernel_params{};
|
||||
|
||||
kernel_params.b = options.b;
|
||||
kernel_params.seqlen_q = options.s_q;
|
||||
kernel_params.d = options.d;
|
||||
kernel_params.d_v = options.dv;
|
||||
kernel_params.h = options.h_q;
|
||||
kernel_params.h_h_k_ratio = num_heads_ori / num_heads_k;
|
||||
kernel_params.ngroups = ngroups;
|
||||
|
||||
kernel_params.q_ptr = block_Q_T.data().get();
|
||||
kernel_params.k_ptr = block_K.data().get();
|
||||
// TODO: block_V is not used in the example
|
||||
kernel_params.v_ptr = block_K.data().get();
|
||||
kernel_params.o_ptr = block_O.data().get();
|
||||
kernel_params.softmax_lse_ptr = block_LSE.data().get();
|
||||
|
||||
kernel_params.q_batch_stride = seqlen_q * num_heads * options.d;
|
||||
kernel_params.k_batch_stride = page_block_size * options.h_kv * options.d;
|
||||
kernel_params.v_batch_stride = page_block_size * options.h_kv * options.dv;
|
||||
kernel_params.o_batch_stride = options.s_q * options.h_q * options.dv;
|
||||
|
||||
kernel_params.q_row_stride = num_heads * options.d;
|
||||
kernel_params.k_row_stride = options.h_kv * options.d;
|
||||
kernel_params.v_row_stride = options.h_kv * options.dv;
|
||||
kernel_params.o_row_stride = options.h_q * options.dv;
|
||||
|
||||
kernel_params.q_head_stride = options.d;
|
||||
kernel_params.k_head_stride = options.d;
|
||||
kernel_params.v_head_stride = options.dv;
|
||||
kernel_params.o_head_stride = options.dv;
|
||||
|
||||
kernel_params.block_table = block_T.data().get();
|
||||
kernel_params.block_table_batch_stride = max_num_blocks_per_seq;
|
||||
kernel_params.page_block_size = page_block_size;
|
||||
|
||||
kernel_params.tile_scheduler_metadata_ptr = block_MD.data().get();
|
||||
kernel_params.num_splits_ptr = block_S.data().get();
|
||||
|
||||
kernel_params.softmax_lseaccum_ptr = block_LSE_Accum.data().get();
|
||||
kernel_params.oaccum_ptr = block_O_Accum.data().get();
|
||||
|
||||
kernel_params.is_causal = is_causal;
|
||||
kernel_params.scale_softmax = options.softmax_scale;
|
||||
kernel_params.scale_softmax_log2 = std::log2(options.softmax_scale);
|
||||
|
||||
kernel_params.cu_seqlens_k = block_C.data().get();
|
||||
|
||||
kernel_params.num_sm_parts = num_sm_parts;
|
||||
|
||||
assert(head_size == 576);
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(kernel_params, stream);
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
transpose(
|
||||
block_O, block_O_T,
|
||||
cute::make_shape(options.b, seqlen_q_ori, ngroups, num_heads_k, options.dv));
|
||||
transpose(
|
||||
block_LSE, block_LSE_T,
|
||||
cute::make_shape(options.b, num_heads_k, seqlen_q_ori, ngroups, 1));
|
||||
|
||||
// TODO: reference check
|
||||
|
||||
printf("run done\n");
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 7)) {
|
||||
std::cerr << "This example requires CUDA 12.7 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 90)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
TestBed testbed{};
|
||||
testbed.run(options);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,36 +0,0 @@
|
||||
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Sparse kernel in this example triggers an ICE in gcc 7.5
|
||||
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0))
|
||||
cutlass_example_add_executable(
|
||||
68_hopper_flash_mla
|
||||
68_hopper_flash_mla.cu
|
||||
)
|
||||
endif()
|
||||
12
examples/68_hopper_flash_mla/README.md
Normal file
12
examples/68_hopper_flash_mla/README.md
Normal file
@ -0,0 +1,12 @@
|
||||
# Hopper FlashMLA - Examples
|
||||
The codes in this example are migrated from [FlashMLA](https://github.com/deepseek-ai/FlashMLA/tree/main), it implements an efficient MLA decoding kernel for Hopper GPU.
|
||||
|
||||
# Run the example
|
||||
### Install
|
||||
```
|
||||
python setup.py install
|
||||
```
|
||||
### Run the test
|
||||
```
|
||||
python tests/test_flash_mla.py
|
||||
```
|
||||
213
examples/68_hopper_flash_mla/csrc/flash_api.cpp
Normal file
213
examples/68_hopper_flash_mla/csrc/flash_api.cpp
Normal file
@ -0,0 +1,213 @@
|
||||
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
|
||||
|
||||
#include <torch/python.h>
|
||||
#include <torch/nn/functional.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cutlass/fast_math.h>
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
std::vector<at::Tensor>
|
||||
get_mla_metadata(
|
||||
at::Tensor &seqlens_k,
|
||||
const int num_heads_per_head_k,
|
||||
const int num_heads_k
|
||||
) {
|
||||
// This should match the logic in the MLA kernel.
|
||||
static constexpr int block_size_m = 64;
|
||||
static constexpr int block_size_n = 64;
|
||||
static constexpr int fixed_overhead_num_blocks = 5;
|
||||
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
TORCH_CHECK(seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
|
||||
|
||||
int batch_size = seqlens_k.size(0);
|
||||
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
|
||||
auto options = seqlens_k.options();
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
int sm_count = dprops->multiProcessorCount;
|
||||
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
|
||||
|
||||
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
|
||||
auto num_splits = torch::empty({batch_size + 1}, options);
|
||||
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
int *num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Mla_metadata_params params = {};
|
||||
params.seqlens_k_ptr = seqlens_k_ptr;
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
|
||||
params.num_splits_ptr = num_splits_ptr;
|
||||
params.batch_size = batch_size;
|
||||
params.block_size_n = block_size_n;
|
||||
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
|
||||
params.num_sm_parts = num_sm_parts;
|
||||
get_mla_metadata_func(params, stream);
|
||||
|
||||
return {tile_scheduler_metadata, num_splits};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd_kvcache_mla(
|
||||
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
|
||||
std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
||||
const int head_size_v,
|
||||
const at::Tensor &seqlens_k, // batch_size
|
||||
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
|
||||
const float softmax_scale,
|
||||
bool is_causal,
|
||||
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
||||
const at::Tensor &num_splits // batch_size + 1
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90);
|
||||
|
||||
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
CHECK_DEVICE(block_table);
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
|
||||
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q_ori = sizes[1];
|
||||
const int num_heads_ori = sizes[2];
|
||||
const int head_size = sizes[3];
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
|
||||
|
||||
const int max_num_blocks_per_seq = block_table.size(1);
|
||||
const int num_blocks = kcache.size(0);
|
||||
const int page_block_size = kcache.size(1);
|
||||
const int num_heads_k = kcache.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be postive");
|
||||
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (seqlen_q_ori == 1) { is_causal = false; }
|
||||
|
||||
const int ngroups = num_heads_ori / num_heads_k;
|
||||
const int seqlen_q = seqlen_q_ori * ngroups;
|
||||
const int num_heads = num_heads_k;
|
||||
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q, num_heads, head_size});
|
||||
|
||||
int head_size_k = head_size;
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
|
||||
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
|
||||
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
||||
|
||||
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
CHECK_CONTIGUOUS(seqlens_k);
|
||||
CHECK_SHAPE(seqlens_k, batch_size);
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
|
||||
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
|
||||
Flash_fwd_mla_params params = {};
|
||||
// Set the sizes.
|
||||
params.b = batch_size;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
|
||||
params.h = num_heads;
|
||||
params.h_h_k_ratio = num_heads / num_heads_k;
|
||||
params.ngroups = ngroups;
|
||||
params.is_causal = is_causal;
|
||||
params.d = head_size;
|
||||
params.d_v = head_size_v;
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = kcache.data_ptr();
|
||||
params.v_ptr = vcache.data_ptr();
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.softmax_lse_ptr = softmax_lse.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.q_batch_stride = q.stride(0);
|
||||
params.k_batch_stride = kcache.stride(0);
|
||||
params.v_batch_stride = vcache.stride(0);
|
||||
params.o_batch_stride = out.stride(0);
|
||||
params.q_row_stride = q.stride(-3);
|
||||
params.k_row_stride = kcache.stride(-3);
|
||||
params.v_row_stride = vcache.stride(-3);
|
||||
params.o_row_stride = out.stride(-3);
|
||||
params.q_head_stride = q.stride(-2);
|
||||
params.k_head_stride = kcache.stride(-2);
|
||||
params.v_head_stride = vcache.stride(-2);
|
||||
params.o_head_stride = out.stride(-2);
|
||||
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
|
||||
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
|
||||
CHECK_DEVICE(tile_scheduler_metadata);
|
||||
CHECK_CONTIGUOUS(tile_scheduler_metadata);
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
params.num_sm_parts = tile_scheduler_metadata.size(0);
|
||||
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
|
||||
CHECK_DEVICE(num_splits);
|
||||
CHECK_CONTIGUOUS(num_splits);
|
||||
params.num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
|
||||
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
||||
params.oaccum_ptr = out_accum.data_ptr();
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(head_size == 576);
|
||||
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
|
||||
}
|
||||
#ifndef FLASH_MLA_DISABLE_FP16
|
||||
else if (q_dtype == torch::kHalf) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported tensor dtype for query");
|
||||
}
|
||||
|
||||
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
|
||||
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
|
||||
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
|
||||
|
||||
return {out, softmax_lse};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "FlashMLA";
|
||||
m.def("get_mla_metadata", &get_mla_metadata);
|
||||
m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,3 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
@ -1,5 +1,3 @@
|
||||
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/flash_fwd_mla_kernel.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
@ -578,7 +576,6 @@ template<typename Kernel_traits, typename SharedStorage>
|
||||
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
||||
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
|
||||
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
|
||||
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
|
||||
constexpr size_t smem_size = sizeof(SharedStorage);
|
||||
@ -586,6 +583,7 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream
|
||||
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
});
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
|
||||
dim3 grid_combine(params.b * params.h * params.seqlen_q);
|
||||
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
|
||||
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
|
||||
@ -603,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream)
|
||||
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
|
||||
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
|
||||
}
|
||||
|
||||
static constexpr int MaxBatchSize = 4096;
|
||||
|
||||
__global__ void __launch_bounds__(256, 1, 1)
|
||||
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
|
||||
int *seqlens_k_ptr = params.seqlens_k_ptr;
|
||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
|
||||
int *num_splits_ptr = params.num_splits_ptr;
|
||||
int batch_size = params.batch_size;
|
||||
int block_size_n = params.block_size_n;
|
||||
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
|
||||
int num_sm_parts = params.num_sm_parts;
|
||||
|
||||
__shared__ int num_blocks_shared[MaxBatchSize];
|
||||
__shared__ int num_splits_shared[MaxBatchSize];
|
||||
|
||||
int total_num_blocks = 0;
|
||||
for (int i = threadIdx.x; i < batch_size; i += 32) {
|
||||
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
|
||||
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
|
||||
num_blocks_shared[i] = num_blocks;
|
||||
}
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
|
||||
|
||||
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
|
||||
num_splits_shared[0] = 0;
|
||||
for (int i = 0; i < num_sm_parts; ++i) {
|
||||
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
|
||||
tile_scheduler_metadata0[0] = now_idx;
|
||||
tile_scheduler_metadata0[1] = now_block * block_size_n;
|
||||
tile_scheduler_metadata1 = now_n_split_idx;
|
||||
int remain_payload = payload;
|
||||
while (now_idx < batch_size) {
|
||||
int num_blocks = num_blocks_shared[now_idx];
|
||||
int now_remain_blocks = num_blocks - now_block;
|
||||
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
|
||||
cum_num_splits += now_n_split_idx + 1;
|
||||
num_splits_shared[now_idx + 1] = cum_num_splits;
|
||||
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
|
||||
++now_idx;
|
||||
now_block = 0;
|
||||
now_n_split_idx = 0;
|
||||
} else {
|
||||
if (remain_payload - fixed_overhead_num_blocks > 0) {
|
||||
now_block += remain_payload - fixed_overhead_num_blocks;
|
||||
++now_n_split_idx;
|
||||
remain_payload = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
|
||||
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
|
||||
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
|
||||
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
|
||||
}
|
||||
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
for (int i = threadIdx.x; i <= batch_size; i += 32) {
|
||||
num_splits_ptr[i] = num_splits_shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) {
|
||||
FLASH_ASSERT(params.batch_size < MaxBatchSize);
|
||||
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
77
examples/68_hopper_flash_mla/csrc/flash_fwd_mla_metadata.cu
Normal file
77
examples/68_hopper_flash_mla/csrc/flash_fwd_mla_metadata.cu
Normal file
@ -0,0 +1,77 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
static constexpr int MaxBatchSize = 4096;
|
||||
|
||||
__global__ void __launch_bounds__(256, 1, 1)
|
||||
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
|
||||
int *seqlens_k_ptr = params.seqlens_k_ptr;
|
||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
|
||||
int *num_splits_ptr = params.num_splits_ptr;
|
||||
int batch_size = params.batch_size;
|
||||
int block_size_n = params.block_size_n;
|
||||
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
|
||||
int num_sm_parts = params.num_sm_parts;
|
||||
|
||||
__shared__ int num_blocks_shared[MaxBatchSize];
|
||||
__shared__ int num_splits_shared[MaxBatchSize];
|
||||
|
||||
int total_num_blocks = 0;
|
||||
for (int i = threadIdx.x; i < batch_size; i += 32) {
|
||||
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
|
||||
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
|
||||
num_blocks_shared[i] = num_blocks;
|
||||
}
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
|
||||
|
||||
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
|
||||
num_splits_shared[0] = 0;
|
||||
for (int i = 0; i < num_sm_parts; ++i) {
|
||||
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
|
||||
tile_scheduler_metadata0[0] = now_idx;
|
||||
tile_scheduler_metadata0[1] = now_block * block_size_n;
|
||||
tile_scheduler_metadata1 = now_n_split_idx;
|
||||
int remain_payload = payload;
|
||||
while (now_idx < batch_size) {
|
||||
int num_blocks = num_blocks_shared[now_idx];
|
||||
int now_remain_blocks = num_blocks - now_block;
|
||||
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
|
||||
cum_num_splits += now_n_split_idx + 1;
|
||||
num_splits_shared[now_idx + 1] = cum_num_splits;
|
||||
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
|
||||
++now_idx;
|
||||
now_block = 0;
|
||||
now_n_split_idx = 0;
|
||||
} else {
|
||||
if (remain_payload - fixed_overhead_num_blocks > 0) {
|
||||
now_block += remain_payload - fixed_overhead_num_blocks;
|
||||
++now_n_split_idx;
|
||||
remain_payload = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
|
||||
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
|
||||
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
|
||||
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
|
||||
}
|
||||
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
for (int i = threadIdx.x; i <= batch_size; i += 32) {
|
||||
num_splits_ptr[i] = num_splits_shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) {
|
||||
FLASH_ASSERT(params.batch_size < MaxBatchSize);
|
||||
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
@ -1,5 +1,3 @@
|
||||
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/flash_mla.h
|
||||
|
||||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -62,4 +60,4 @@ struct Mla_metadata_params {
|
||||
int num_sm_parts;
|
||||
};
|
||||
|
||||
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream);
|
||||
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream);
|
||||
@ -1,5 +1,3 @@
|
||||
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/named_barrier.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/barrier.h"
|
||||
@ -14,4 +12,4 @@ enum class NamedBarriers {
|
||||
SoftmaxReady = 2,
|
||||
};
|
||||
|
||||
} // flash
|
||||
} // flash
|
||||
@ -194,4 +194,4 @@ struct Softmax {
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace flash
|
||||
} // namespace flash
|
||||
@ -1,5 +1,3 @@
|
||||
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/static_switch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#define CHECK_CUDA(call) \
|
||||
@ -64,4 +62,4 @@
|
||||
} else { \
|
||||
FLASH_ASSERT(false); \
|
||||
} \
|
||||
}()
|
||||
}()
|
||||
@ -235,4 +235,4 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layou
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
} // namespace flash
|
||||
@ -1,79 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Process the input&output data in Flash MLA kernels
|
||||
*/
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
|
||||
template <typename T>
|
||||
struct fill_nan_functor {
|
||||
T* blocked_k;
|
||||
const int* cache_seqlens;
|
||||
int max_seqlen_pad, h_kv, d;
|
||||
|
||||
__host__ __device__
|
||||
fill_nan_functor(T* bk, const int* cs, int msp, int hkv, int dd) :
|
||||
blocked_k(bk), cache_seqlens(cs), max_seqlen_pad(msp), h_kv(hkv), d(dd) {}
|
||||
|
||||
__host__ __device__
|
||||
void operator()(int idx) {
|
||||
auto NAN_VALUE = std::numeric_limits<T>::quiet_NaN();
|
||||
int h_d_size = h_kv * d;
|
||||
int seq_h_d_size = max_seqlen_pad * h_d_size;
|
||||
|
||||
int batch_idx = idx / seq_h_d_size;
|
||||
int pos_in_seq = (idx % seq_h_d_size) / h_d_size;
|
||||
|
||||
if (pos_in_seq >= cache_seqlens[batch_idx]) {
|
||||
blocked_k[idx] = NAN_VALUE; // NaN
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void fill_nan(
|
||||
T* d_blocked_k,
|
||||
const int* d_cache_seqlens,
|
||||
int b, int max_seqlen_pad, int h_kv, int d
|
||||
) {
|
||||
int total_elements = b * max_seqlen_pad * h_kv * d;
|
||||
|
||||
thrust::for_each(
|
||||
thrust::device,
|
||||
thrust::counting_iterator<int>(0),
|
||||
thrust::counting_iterator<int>(total_elements),
|
||||
fill_nan_functor(d_blocked_k, d_cache_seqlens, max_seqlen_pad, h_kv, d)
|
||||
);
|
||||
}
|
||||
6
examples/68_hopper_flash_mla/flash_mla/__init__.py
Normal file
6
examples/68_hopper_flash_mla/flash_mla/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
__version__ = "1.0.0"
|
||||
|
||||
from flash_mla.flash_mla_interface import (
|
||||
get_mla_metadata,
|
||||
flash_mla_with_kvcache,
|
||||
)
|
||||
@ -0,0 +1,67 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import flash_mla_cuda
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
|
||||
Returns:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head dimension of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
|
||||
Returns:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
|
||||
q,
|
||||
k_cache,
|
||||
None,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
)
|
||||
return out, softmax_lse
|
||||
87
examples/68_hopper_flash_mla/setup.py
Normal file
87
examples/68_hopper_flash_mla/setup.py
Normal file
@ -0,0 +1,87 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
from torch.utils.cpp_extension import (
|
||||
BuildExtension,
|
||||
CUDAExtension,
|
||||
IS_WINDOWS,
|
||||
)
|
||||
|
||||
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
nvcc_threads = os.getenv("NVCC_THREADS") or "32"
|
||||
return nvcc_extra_args + ["--threads", nvcc_threads]
|
||||
|
||||
def get_sources():
|
||||
sources = [
|
||||
"csrc/flash_api.cpp",
|
||||
"csrc/flash_fwd_mla_bf16_sm90.cu",
|
||||
"csrc/flash_fwd_mla_metadata.cu",
|
||||
]
|
||||
|
||||
if not DISABLE_FP16:
|
||||
sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
|
||||
|
||||
return sources
|
||||
|
||||
def get_features_args():
|
||||
features_args = []
|
||||
if DISABLE_FP16:
|
||||
features_args.append("-DFLASH_MLA_DISABLE_FP16")
|
||||
return features_args
|
||||
|
||||
cc_flag = []
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_90a,code=sm_90a")
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if IS_WINDOWS:
|
||||
cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
|
||||
else:
|
||||
cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
|
||||
|
||||
ext_modules = []
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="flash_mla_cuda",
|
||||
sources=get_sources(),
|
||||
extra_compile_args={
|
||||
"cxx": cxx_args + get_features_args(),
|
||||
"nvcc": append_nvcc_threads(
|
||||
[
|
||||
"-O3",
|
||||
"-std=c++17",
|
||||
"-DNDEBUG",
|
||||
"-D_USE_MATH_DEFINES",
|
||||
"-Wno-deprecated-declarations",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
"--use_fast_math",
|
||||
"--ptxas-options=-v,--register-usage-level=10"
|
||||
]
|
||||
+ cc_flag
|
||||
) + get_features_args(),
|
||||
},
|
||||
include_dirs=[
|
||||
Path(this_dir) / "csrc",
|
||||
Path(this_dir) / ".." / ".." / "include",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
setup(
|
||||
name="flash_mla",
|
||||
version="1.0.0",
|
||||
packages=find_packages(include=['flash_mla']),
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
153
examples/68_hopper_flash_mla/tests/test_flash_mla.py
Normal file
153
examples/68_hopper_flash_mla/tests/test_flash_mla.py
Normal file
@ -0,0 +1,153 @@
|
||||
import argparse
|
||||
import math
|
||||
import random
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
|
||||
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
value = value.float()
|
||||
key = key.repeat_interleave(h_q // h_kv, dim=0)
|
||||
value = value.repeat_interleave(h_q // h_kv, dim=0)
|
||||
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
|
||||
if is_causal:
|
||||
s_q = query.shape[-2]
|
||||
s_k = key.shape[-2]
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
attn_weight += attn_bias
|
||||
lse = attn_weight.logsumexp(dim=-1)
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
return attn_weight @ value, lse
|
||||
|
||||
|
||||
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||
amax_diff = (x - y).abs().max().item()
|
||||
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
|
||||
print(
|
||||
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
|
||||
)
|
||||
|
||||
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
|
||||
total_seqlens = cache_seqlens.sum().item()
|
||||
mean_seqlens = cache_seqlens.float().mean().int().item()
|
||||
max_seqlen = cache_seqlens.max().item()
|
||||
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
||||
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
|
||||
|
||||
q = torch.randn(b, s_q, h_q, d)
|
||||
block_size = 64
|
||||
block_table = torch.arange(
|
||||
b * max_seqlen_pad // block_size, dtype=torch.int32
|
||||
).view(b, max_seqlen_pad // block_size)
|
||||
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
||||
for i in range(b):
|
||||
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
|
||||
float("nan")
|
||||
)
|
||||
blocked_v = blocked_k[..., :dv]
|
||||
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
cache_seqlens, s_q * h_q // h_kv, h_kv
|
||||
)
|
||||
|
||||
def flash_mla():
|
||||
return flash_mla_with_kvcache(
|
||||
q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
def ref_mla():
|
||||
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
begin = i * max_seqlen_pad
|
||||
end = begin + cache_seqlens[i]
|
||||
O, LSE = scaled_dot_product_attention(
|
||||
q[i].transpose(0, 1),
|
||||
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
|
||||
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
|
||||
h_q=h_q,
|
||||
h_kv=h_kv,
|
||||
is_causal=causal,
|
||||
)
|
||||
out[i] = O.transpose(0, 1)
|
||||
lse[i] = LSE
|
||||
return out, lse
|
||||
|
||||
out_flash, lse_flash = flash_mla()
|
||||
out_torch, lse_torch = ref_mla()
|
||||
cal_diff(out_flash, out_torch, "out")
|
||||
cal_diff(lse_flash, lse_torch, "lse")
|
||||
|
||||
t = triton.testing.do_bench(flash_mla)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
|
||||
torch.finfo(q.dtype).bits // 8
|
||||
)
|
||||
print(
|
||||
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
|
||||
)
|
||||
|
||||
|
||||
def main(torch_dtype):
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(torch_dtype)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
h_kv = 1
|
||||
d, dv = 576, 512
|
||||
causal = True
|
||||
|
||||
for b in [128]:
|
||||
for s in [4096, 8192]:
|
||||
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
|
||||
for s_q in [1, 2]: # MTP = 1, 2
|
||||
for varlen in [False, True]:
|
||||
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["bf16", "fp16"],
|
||||
default="bf16",
|
||||
help="Data type to use for testing (bf16 or fp16)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
torch_dtype = torch.bfloat16
|
||||
if args.dtype == "fp16":
|
||||
torch_dtype = torch.float16
|
||||
|
||||
main(torch_dtype)
|
||||
@ -1,131 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
using namespace cutlass;
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
template <class Element_>
|
||||
struct TransposeKernel {
|
||||
using Element = Element_;
|
||||
|
||||
// TODO: use more threads to copy tensor
|
||||
static constexpr int MaxThreadsPerBlock = 1;
|
||||
static constexpr int MinBlocksPerMultiprocessor = 1;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
static constexpr int AlignmentBytes = 16;
|
||||
|
||||
struct SharedStorage {
|
||||
/* empty */
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size
|
||||
using ProblemShape = cute::tuple<int, int, int, int, int>;
|
||||
|
||||
struct Arguments {
|
||||
const Element *ptrS{nullptr};
|
||||
Element *ptrD{nullptr};
|
||||
ProblemShape problem_shape{};
|
||||
};
|
||||
|
||||
struct Params {
|
||||
const Element *ptrS{nullptr};
|
||||
Element *ptrD{nullptr};
|
||||
ProblemShape problem_shape{};
|
||||
};
|
||||
|
||||
static Params
|
||||
to_underlying_arguments(Arguments const&args, void* workspace) {
|
||||
return Params{
|
||||
args.ptrS,
|
||||
args.ptrD,
|
||||
args.problem_shape
|
||||
};
|
||||
}
|
||||
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
return size_t(0);
|
||||
}
|
||||
|
||||
static Status
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
auto [B, S, H, G, D] = params.problem_shape;
|
||||
return dim3(B*S, H, G);
|
||||
}
|
||||
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
operator()(Params params, [[maybe_unused]] char* smem_buf = nullptr) {
|
||||
auto [B, S, H, G, D] = params.problem_shape;
|
||||
|
||||
// default is column-major layout
|
||||
auto src_layout_ = make_layout(make_shape(D,G,H,S,B));
|
||||
auto src_layout = make_layout(reverse(src_layout_.shape()), reverse(src_layout_.stride()));
|
||||
|
||||
auto dst_layout_ = make_layout(make_shape(D,H,G,S,B));
|
||||
auto dst_layout = make_layout(reverse(dst_layout_.shape()), reverse(dst_layout_.stride()));
|
||||
|
||||
auto src_tensor = make_tensor(
|
||||
make_gmem_ptr(params.ptrS),
|
||||
group<0,2>(src_layout)
|
||||
);
|
||||
|
||||
auto dst_tensor = make_tensor(
|
||||
make_gmem_ptr(params.ptrD),
|
||||
group<0,2>(dst_layout)
|
||||
);
|
||||
|
||||
auto tS = src_tensor(blockIdx.x, blockIdx.y, blockIdx.z, _);
|
||||
auto tD = dst_tensor(blockIdx.x, blockIdx.y, blockIdx.z, _);
|
||||
|
||||
copy(tS, tD);
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user