Flash MLA Support - Step 2 (#2134)

* initial commit

* initial commit

* fix some error

* update

* bugfix

* bugfix

* change name

* Add input&output process

* minor

* update

* initial commit

* initial commit

* fix some error

* update

* bugfix

* bugfix

* change name

* minor

* update
This commit is contained in:
myu-guo
2025-02-26 12:18:03 +08:00
committed by GitHub
parent 415d587ebf
commit af5519d938
5 changed files with 310 additions and 33 deletions

View File

@ -34,9 +34,12 @@
*/
#include <cassert>
#include <iostream>
#include "flash_fwd_mla_kernel.h"
#include "flash_mla.h"
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include <thrust/universal_vector.h>
#include <thrust/device_vector.h>
@ -55,6 +58,11 @@
#include <cuda_runtime.h>
#include "flash_fwd_mla_kernel.h"
#include "flash_mla.h"
#include "fill_nan.h"
#include "transform.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#define CUDA_CHECK(status) \
@ -138,14 +146,8 @@ struct Options {
return out;
}
/// TOOD:Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
// uint64_t flop = uint64_t(2) * m * n * k;
// double gflop = double(flop) / double(1.0e9);
// return gflop / runtime_s;
}
/// TODO: Compute performance in GFLOP/s
};
/// Helper to initialize a block of device data
@ -249,20 +251,79 @@ auto initialize_metadata(
get_mla_metadata_func(params, stream);
}
// Transpose dimensions 2 and 3 of a rank-5 row-major tensor:
//   block_S viewed as (d0, d1, d2, d3, d4)  ->  block_D viewed as (d0, d1, d3, d2, d4)
// launched through CUTLASS's TransformUniversalAdapter over TransposeKernel.
//
// NOTE: this helper synchronizes the whole device before and after the launch,
// so it is only suitable for test/example code, not a latency-sensitive path.
// On any failure it reports to stderr and returns; block_D is then unmodified.
template <class Element>
void transpose(
    thrust::universal_vector<Element> &block_S,
    thrust::universal_vector<Element> &block_D,
    cute::tuple<int, int, int, int, int> problem_shape) {

  using Operator = cutlass::transform::device::TransformUniversalAdapter<TransposeKernel<Element>>;

  // Drain all outstanding work first so that a pre-existing asynchronous
  // failure is not misattributed to the transpose kernel launched below.
  cudaError_t result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Error before launching the Transpose kernel. Last CUDA error is: "
              << cudaGetErrorString(result) << std::endl;
    return;  // do not launch on a context that already holds an error
  }

  typename Operator::Arguments arguments{
      block_S.data().get(),
      block_D.data().get(),
      problem_shape,
  };

  Operator op;

  size_t workspace_size = Operator::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cutlass::Status status = op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "This kernel is not supported. Last CUDA error is: "
              << cudaGetErrorString(cudaGetLastError()) << std::endl;
    return;
  }

  status = op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
              << cudaGetErrorString(cudaGetLastError()) << std::endl;
    return;
  }

  status = op.run();
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
              << cudaGetErrorString(cudaGetLastError()) << std::endl;
    return;
  }

  // Surface asynchronous execution errors from the kernel itself.
  result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
              << cudaGetErrorString(result) << std::endl;
    return;
  }
}
struct TestBed {
using Element = cutlass::bfloat16_t;
using ElementAcc = float;
thrust::universal_vector<Element> block_Q; // query
thrust::universal_vector<Element> block_K; // blocked key
thrust::universal_vector<int32_t> block_T; // block table
thrust::universal_vector<int32_t> block_C; // cache seqlens
thrust::universal_vector<Element> block_Q; // query
thrust::universal_vector<Element> block_Q_T; // query transpose
thrust::universal_vector<Element> block_K; // blocked key
thrust::universal_vector<int32_t> block_T; // block table
thrust::universal_vector<int32_t> block_C; // cache seqlens
// TODO: block_V is not used in the example
// thrust::universal_vector<Element> block_V; // dv
thrust::universal_vector<int32_t> block_MD; // mla metadata
thrust::universal_vector<int32_t> block_S; // num splits
thrust::universal_vector<Element> block_O; // output
thrust::universal_vector<Element> block_LSE; // lse
thrust::universal_vector<int32_t> block_MD; // mla metadata
thrust::universal_vector<int32_t> block_S; // num splits
thrust::universal_vector<Element> block_O; // output
thrust::universal_vector<Element> block_LSE; // lse
thrust::universal_vector<Element> block_O_T; // output transpose
thrust::universal_vector<Element> block_LSE_T; // lse transpose
thrust::universal_vector<ElementAcc> block_O_Accum; // output
thrust::universal_vector<ElementAcc> block_LSE_Accum; // lse
@ -289,6 +350,7 @@ struct TestBed {
// Query: [b, s_q, h_q, d]
block_Q.resize(options.b * options.s_q * options.h_q * options.d);
block_Q_T.resize(options.b * options.s_q * options.h_q * options.d);
// Block table: [b, max_num_blocks_per_seq]
block_T.resize(total_blocks);
@ -300,7 +362,9 @@ struct TestBed {
initialize_values(block_T, cutlass::Distribution::Sequential, seed + 3);
initialize_values(block_K, cutlass::Distribution::Gaussian, seed + 5);
// TODO: Set the exceeding part to NaN
// Set the exceeding part to NaN
fill_nan(block_K.data().get(), block_C.data().get(),
options.b, max_seqlen_pad, options.h_kv, options.d);
initialize_metadata(block_C, block_MD, block_S, num_sm_parts, options);
@ -310,9 +374,10 @@ struct TestBed {
// LSE: [batch_size, num_heads, seqlen_q]
block_LSE.resize(options.b * num_heads * seqlen_q);
block_LSE_T.resize(options.b * seqlen_q * num_heads);
// Output: [batch_size, seqlen_q, num_heads, head_size_v]
block_O.resize(options.b * seqlen_q * num_heads * options.dv);
block_O_T.resize(options.b * seqlen_q * num_heads * options.dv);
auto softmax_lse_size = (options.b + num_sm_parts) * num_heads * seqlen_q;
auto out_accum_size = (options.b + num_sm_parts) * num_heads * seqlen_q * options.dv;
@ -355,8 +420,10 @@ struct TestBed {
int seqlen_q = seqlen_q_ori * ngroups;
int num_heads = num_heads_k;
// TODO: preprocess the query
// q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3).reshape({batch_size, seqlen_q, num_heads, head_size});
// preprocess the query
transpose(
block_Q, block_Q_T,
cute::make_shape(options.b, seqlen_q_ori, num_heads_k, ngroups, options.d));
cudaStream_t stream{nullptr};
@ -371,7 +438,7 @@ struct TestBed {
kernel_params.h_h_k_ratio = num_heads_ori / num_heads_k;
kernel_params.ngroups = ngroups;
kernel_params.q_ptr = block_Q.data().get();
kernel_params.q_ptr = block_Q_T.data().get();
kernel_params.k_ptr = block_K.data().get();
// TODO: block_V is not used in the example
kernel_params.v_ptr = block_K.data().get();
@ -416,11 +483,12 @@ struct TestBed {
CUDA_CHECK(cudaDeviceSynchronize());
// TODO: postprocess the output
// out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
// .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
// softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
// .reshape({batch_size, num_heads_ori, seqlen_q_ori});
transpose(
block_O, block_O_T,
cute::make_shape(options.b, seqlen_q_ori, ngroups, num_heads_k, options.dv));
transpose(
block_LSE, block_LSE_T,
cute::make_shape(options.b, num_heads_k, seqlen_q_ori, ngroups, 1));
// TODO: reference check
@ -433,9 +501,8 @@ struct TestBed {
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// and must have compute capability at least 100a.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 7)) {
std::cerr << "This example requires CUDA 12.7 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}

View File

@ -0,0 +1,79 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Process the input & output data in Flash MLA kernels
*/
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>
// Functor that poisons the padding region of the blocked key cache with NaN.
// blocked_k is assumed to be laid out row-major as [b, max_seqlen_pad, h_kv, d]
// (this matches the caller, which passes options.b / max_seqlen_pad /
// options.h_kv / options.d); positions at or beyond cache_seqlens[batch] are
// padding and must never contribute to a correct attention result.
template <typename T>
struct fill_nan_functor {
  T* blocked_k;              // device pointer to the blocked key cache
  const int* cache_seqlens;  // device pointer; valid sequence length per batch
  int max_seqlen_pad, h_kv, d;

  __host__ __device__
  fill_nan_functor(T* bk, const int* cs, int msp, int hkv, int dd) :
      blocked_k(bk), cache_seqlens(cs), max_seqlen_pad(msp), h_kv(hkv), d(dd) {}

  // idx is a flat element index into [b, max_seqlen_pad, h_kv, d].
  // 64-bit so the decomposition cannot overflow for large caches.
  __host__ __device__
  void operator()(long long idx) {
    // std::numeric_limits<T> relies on the specialization CUTLASS provides
    // for its numeric types (e.g. bfloat16_t).
    auto NAN_VALUE = std::numeric_limits<T>::quiet_NaN();
    long long h_d_size = (long long)h_kv * d;
    long long seq_h_d_size = (long long)max_seqlen_pad * h_d_size;
    int batch_idx = (int)(idx / seq_h_d_size);
    int pos_in_seq = (int)((idx % seq_h_d_size) / h_d_size);
    // Everything past the cached length of this sequence is padding.
    if (pos_in_seq >= cache_seqlens[batch_idx]) {
      blocked_k[idx] = NAN_VALUE; // NaN
    }
  }
};

// Fill the out-of-range part of the blocked key cache with NaN so that any
// kernel accidentally reading past a sequence's cached length produces a
// detectable NaN in its output.
//   d_blocked_k:     device pointer, [b, max_seqlen_pad, h_kv, d] row-major
//   d_cache_seqlens: device pointer, [b] valid lengths
template <typename T>
void fill_nan(
    T* d_blocked_k,
    const int* d_cache_seqlens,
    int b, int max_seqlen_pad, int h_kv, int d
) {
  // Accumulate the element count in 64 bits: b * max_seqlen_pad * h_kv * d
  // can exceed INT_MAX for large batch/sequence configurations.
  long long total_elements = (long long)b * max_seqlen_pad * h_kv * d;
  thrust::for_each(
      thrust::device,
      thrust::counting_iterator<long long>(0),
      thrust::counting_iterator<long long>(total_elements),
      fill_nan_functor<T>(d_blocked_k, d_cache_seqlens, max_seqlen_pad, h_kv, d)
  );
}

View File

@ -1,4 +1,4 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/softmax.h
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
#pragma once

View File

@ -0,0 +1,131 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
using namespace cutlass;
#include "cutlass/arch/arch.h"
// Device kernel that transposes dimensions 2 and 3 of a rank-5 row-major
// tensor: (B, S, H, G, D) -> (B, S, G, H, D). It is launched through
// CUTLASS's TransformUniversalAdapter; one thread block handles one
// (b*s, h, g) triple and copies the innermost D-length contiguous vector.
template <class Element_>
struct TransposeKernel {
  using Element = Element_;

  // The whole thread block cooperates on one D-length vector via a
  // block-strided loop in operator() (resolves the previous single-thread TODO).
  static constexpr int MaxThreadsPerBlock = 128;
  static constexpr int MinBlocksPerMultiprocessor = 1;
  using ArchTag = cutlass::arch::Sm90;
  static constexpr int AlignmentBytes = 16;

  // No shared memory needed: the copy streams global -> global directly.
  struct SharedStorage {
    /* empty */
  };
  static constexpr int SharedStorageSize = sizeof(SharedStorage);

  // batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size
  using ProblemShape = cute::tuple<int, int, int, int, int>;

  struct Arguments {
    const Element *ptrS{nullptr};
    Element *ptrD{nullptr};
    ProblemShape problem_shape{};
  };

  struct Params {
    const Element *ptrS{nullptr};
    Element *ptrD{nullptr};
    ProblemShape problem_shape{};
  };

  // Stateless kernel: arguments pass through unchanged, workspace is unused.
  static Params
  to_underlying_arguments(Arguments const&args, void* workspace) {
    return Params{
        args.ptrS,
        args.ptrD,
        args.problem_shape
    };
  }

  static Status
  can_implement(Arguments const& args) {
    return Status::kSuccess;
  }

  static size_t
  get_workspace_size(Arguments const& args) {
    return size_t(0);
  }

  static Status
  initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
                       CudaHostAdapter *cuda_adapter = nullptr) {
    return Status::kSuccess;
  }

  // One block per (batch*seq, head, group) triple.
  static dim3
  get_grid_shape(Params const& params) {
    auto [B, S, H, G, D] = params.problem_shape;
    return dim3(B*S, H, G);
  }

  static dim3
  get_block_shape() {
    return dim3(MaxThreadsPerBlock, 1, 1);
  }

  CUTLASS_DEVICE
  void
  operator()(Params params, [[maybe_unused]] char* smem_buf = nullptr) {
    auto [B, S, H, G, D] = params.problem_shape;
    // make_layout defaults to column-major, so build the shapes innermost-first
    // (D,...,B) and reverse shape+stride to obtain row-major (B,...,D) layouts.
    auto src_layout_ = make_layout(make_shape(D,G,H,S,B));
    auto src_layout = make_layout(reverse(src_layout_.shape()), reverse(src_layout_.stride()));
    auto dst_layout_ = make_layout(make_shape(D,H,G,S,B));
    auto dst_layout = make_layout(reverse(dst_layout_.shape()), reverse(dst_layout_.stride()));
    // group<0,2> folds (B,S) into a single mode so blockIdx.x (in [0, B*S))
    // indexes it directly. Resulting modes: src ((B,S), H, G, D),
    // dst ((B,S), G, H, D).
    auto src_tensor = make_tensor(
        make_gmem_ptr(params.ptrS),
        group<0,2>(src_layout)
    );
    auto dst_tensor = make_tensor(
        make_gmem_ptr(params.ptrD),
        group<0,2>(dst_layout)
    );
    // blockIdx.y walks H and blockIdx.z walks G (see get_grid_shape). The
    // destination's G/H modes are swapped relative to the source, so the y/z
    // coordinates must be swapped on the destination slice. Indexing both
    // sides as (x, y, z) would write offset h*D*H + g*D instead of
    // g*H*D + h*D — an aliasing/out-of-bounds bug whenever H != G that is
    // silently masked when H == 1.
    auto tS = src_tensor(blockIdx.x, blockIdx.y, blockIdx.z, _);
    auto tD = dst_tensor(blockIdx.x, blockIdx.z, blockIdx.y, _);
    // Block-strided elementwise copy of the innermost D-length vector.
    for (int i = threadIdx.x; i < D; i += MaxThreadsPerBlock) {
      tD(i) = tS(i);
    }
  }
};

View File

@ -1,4 +1,4 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/utils.h
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h
#pragma once