Flash MLA Support - Step 2 (#2134)
* initial commit * fix some error * update * bugfix * change name * Add input&output process * minor * update
@@ -34,9 +34,12 @@
 */

#include <cassert>
#include <iostream>

#include "flash_fwd_mla_kernel.h"
#include "flash_mla.h"
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
+#include "cutlass/transform/device/transform_universal_adapter.hpp"

#include <thrust/universal_vector.h>
#include <thrust/device_vector.h>
@@ -55,6 +58,11 @@

#include <cuda_runtime.h>

#include "flash_fwd_mla_kernel.h"
#include "flash_mla.h"
+#include "fill_nan.h"
+#include "transform.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

#define CUDA_CHECK(status) \
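The hunk above cuts off at the macro name. For reference, CUTLASS examples typically define this check along the following lines (a sketch of the conventional body, not the text of this commit):

#define CUDA_CHECK(status)                                                \
  {                                                                       \
    cudaError_t error = status;                                           \
    if (error != cudaSuccess) {                                           \
      std::cerr << "Got bad cuda status: " << cudaGetErrorString(error)   \
                << " at line: " << __LINE__ << std::endl;                 \
      exit(EXIT_FAILURE);                                                 \
    }                                                                     \
  }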
@@ -138,14 +146,8 @@ struct Options {
    return out;
  }

  /// TODO: Compute performance in GFLOP/s
  double gflops(double runtime_s) const
  {
    // Two flops per multiply-add
    // uint64_t flop = uint64_t(2) * m * n * k;
    // double gflop = double(flop) / double(1.0e9);
    // return gflop / runtime_s;
    return 0.0; // placeholder until the MLA FLOP count is filled in
  }

};
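One way to fill in the TODO above is to count only the two attention GEMMs (Q.K^T over head dim d, then P.V over head dim dv) at two flops per multiply-add. Options has no field for the average cache sequence length in this diff, so `mean_sk` below is a placeholder argument, not a committed member:

  // hedged sketch, not the committed implementation
  double gflops(double runtime_s, double mean_sk) const
  {
    // b * s_q * h_q query rows, each attending to ~mean_sk KV rows;
    // Q.K^T costs 2*d flops per (query, key) pair and P.V costs 2*dv
    double flop = 2.0 * double(b) * double(s_q) * double(h_q)
                * mean_sk * (double(d) + double(dv));
    return flop / 1.0e9 / runtime_s;
  }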
/// Helper to initialize a block of device data
@@ -249,20 +251,79 @@ auto initialize_metadata(
    get_mla_metadata_func(params, stream);
}

// only transpose the dimensions 2 and 3
template <class Element>
void transpose(
    thrust::universal_vector<Element> &block_S,
    thrust::universal_vector<Element> &block_D,
    cute::tuple<int, int, int, int, int> problem_shape) {

  using Operator = cutlass::transform::device::TransformUniversalAdapter<TransposeKernel<Element>>;

  cudaError_t result;
  result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Error running the Transpose kernel. Last CUDA error is: "
              << cudaGetErrorString(result) << std::endl;
  }

  typename Operator::Arguments arguments{
    block_S.data().get(),
    block_D.data().get(),
    problem_shape,
  };

  Operator op;

  size_t workspace_size = Operator::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cutlass::Status status = op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "This kernel is not supported. Last CUDA error is: "
              << cudaGetErrorString(cudaGetLastError()) << std::endl;
    return;
  }

  status = op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
              << cudaGetErrorString(cudaGetLastError()) << std::endl;
    return;
  }

  status = op.run();
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
              << cudaGetErrorString(cudaGetLastError()) << std::endl;
    return;
  }

  result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
              << cudaGetErrorString(result) << std::endl;
    return;
  }
}

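Because TransposeKernel copies with a single thread per block, a plain host-side reference is handy for validating it. A sketch (transpose_reference is not part of this commit; it relies on thrust::universal_vector being addressable from the host):

template <class Element>
void transpose_reference(
    thrust::universal_vector<Element> const &block_S,
    thrust::universal_vector<Element> &block_D,
    cute::tuple<int, int, int, int, int> problem_shape) {
  auto [B, S, H, G, D] = problem_shape;
  for (int b = 0; b < B; ++b)
  for (int s = 0; s < S; ++s)
  for (int h = 0; h < H; ++h)
  for (int g = 0; g < G; ++g)
  for (int d = 0; d < D; ++d) {
    // src is [B, S, H, G, D]; dst swaps dims 2 and 3 into [B, S, G, H, D]
    size_t src = ((((size_t(b) * S + s) * H + h) * G + g) * D + d);
    size_t dst = ((((size_t(b) * S + s) * G + g) * H + h) * D + d);
    block_D[dst] = block_S[src];
  }
}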
struct TestBed {
  using Element = cutlass::bfloat16_t;
  using ElementAcc = float;

-  thrust::universal_vector<Element> block_Q; // query
-  thrust::universal_vector<Element> block_K; // blocked key
-  thrust::universal_vector<int32_t> block_T; // block table
-  thrust::universal_vector<int32_t> block_C; // cache seqlens
+  thrust::universal_vector<Element> block_Q; // query
+  thrust::universal_vector<Element> block_Q_T; // query transpose
+  thrust::universal_vector<Element> block_K; // blocked key
+  thrust::universal_vector<int32_t> block_T; // block table
+  thrust::universal_vector<int32_t> block_C; // cache seqlens
   // TODO: block_V is not used in the example
   // thrust::universal_vector<Element> block_V; // dv
-  thrust::universal_vector<int32_t> block_MD; // mla metadata
-  thrust::universal_vector<int32_t> block_S; // num splits
-  thrust::universal_vector<Element> block_O; // output
-  thrust::universal_vector<Element> block_LSE; // lse
+  thrust::universal_vector<int32_t> block_MD; // mla metadata
+  thrust::universal_vector<int32_t> block_S; // num splits
+  thrust::universal_vector<Element> block_O; // output
+  thrust::universal_vector<Element> block_LSE; // lse
+  thrust::universal_vector<Element> block_O_T; // output transpose
+  thrust::universal_vector<Element> block_LSE_T; // lse transpose
   thrust::universal_vector<ElementAcc> block_O_Accum; // output accumulator
   thrust::universal_vector<ElementAcc> block_LSE_Accum; // lse accumulator

@@ -289,6 +350,7 @@ struct TestBed {

    // Query: [b, s_q, h_q, d]
    block_Q.resize(options.b * options.s_q * options.h_q * options.d);
+   block_Q_T.resize(options.b * options.s_q * options.h_q * options.d);

    // Block table: [b, max_num_blocks_per_seq]
    block_T.resize(total_blocks);

@@ -300,7 +362,9 @@ struct TestBed {
    initialize_values(block_T, cutlass::Distribution::Sequential, seed + 3);
    initialize_values(block_K, cutlass::Distribution::Gaussian, seed + 5);

-   // TODO: Set the exceeding part to NaN
+   // Set the exceeding part to NaN
+   fill_nan(block_K.data().get(), block_C.data().get(),
+            options.b, max_seqlen_pad, options.h_kv, options.d);

    initialize_metadata(block_C, block_MD, block_S, num_sm_parts, options);

@@ -310,9 +374,10 @@ struct TestBed {

    // LSE: [batch_size, num_heads, seqlen_q]
    block_LSE.resize(options.b * num_heads * seqlen_q);
+   block_LSE_T.resize(options.b * seqlen_q * num_heads);

    // Output: [batch_size, seqlen_q, num_heads, head_size_v]
    block_O.resize(options.b * seqlen_q * num_heads * options.dv);
+   block_O_T.resize(options.b * seqlen_q * num_heads * options.dv);

    auto softmax_lse_size = (options.b + num_sm_parts) * num_heads * seqlen_q;
    auto out_accum_size = (options.b + num_sm_parts) * num_heads * seqlen_q * options.dv;

@@ -355,8 +420,10 @@ struct TestBed {
    int seqlen_q = seqlen_q_ori * ngroups;
    int num_heads = num_heads_k;

-   // TODO: preprocess the query
    // q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3).reshape({batch_size, seqlen_q, num_heads, head_size});
+   // preprocess the query
+   transpose(
+       block_Q, block_Q_T,
+       cute::make_shape(options.b, seqlen_q_ori, num_heads_k, ngroups, options.d));

    cudaStream_t stream{nullptr};

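A concrete reading of the commented view/transpose/reshape above: the h_q query heads split into (num_heads_k, ngroups), and the transpose folds the group index into the query sequence dimension, one KV head at a time. The index identity it implements (b, s, h_k, g here are hypothetical loop indices, not names from this commit):

// q_T[b][s * ngroups + g][h_k][:] = q[b][s][h_k * ngroups + g][:]
//
// e.g. b = 1, seqlen_q_ori = 1, h_q = 128, num_heads_k = 1 gives ngroups = 128:
// the 128 query heads become 128 query rows for the single KV head.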
@@ -371,7 +438,7 @@ struct TestBed {
    kernel_params.h_h_k_ratio = num_heads_ori / num_heads_k;
    kernel_params.ngroups = ngroups;

-   kernel_params.q_ptr = block_Q.data().get();
+   kernel_params.q_ptr = block_Q_T.data().get();
    kernel_params.k_ptr = block_K.data().get();
    // TODO: block_V is not used in the example
    kernel_params.v_ptr = block_K.data().get();

@@ -416,11 +483,12 @@ struct TestBed {

    CUDA_CHECK(cudaDeviceSynchronize());

-   // TODO: postprocess the output
    // out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
    //     .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
    // softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
    //     .reshape({batch_size, num_heads_ori, seqlen_q_ori});
+   transpose(
+       block_O, block_O_T,
+       cute::make_shape(options.b, seqlen_q_ori, ngroups, num_heads_k, options.dv));
+   transpose(
+       block_LSE, block_LSE_T,
+       cute::make_shape(options.b, num_heads_k, seqlen_q_ori, ngroups, 1));

    // TODO: reference check

@@ -433,9 +501,8 @@
int main(int argc, char const **args) {

  // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
  // and must have compute capability at least 100a.
- if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
-   std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
+ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 7)) {
+   std::cerr << "This example requires CUDA 12.7 or newer." << std::endl;
    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
    return 0;
  }

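One cheap stand-in for the reference-check TODO above leverages the NaN poisoning done by fill_nan: a kernel that honors cache_seqlens never reads the poisoned tail of the KV cache, so its outputs must come back NaN-free. A sketch (output_is_finite is hypothetical, not part of this commit):

template <class Element>
bool output_is_finite(thrust::universal_vector<Element> const &out) {
  for (size_t i = 0; i < out.size(); ++i) {
    // universal memory is readable from the host after the device sync
    if (std::isnan(float(out[i]))) return false;
  }
  return true;
}

// e.g. after the postprocess transposes: assert(output_is_finite(block_O_T));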
examples/68_hopper_flash_mla/fill_nan.h (new file, 79 lines)
@@ -0,0 +1,79 @@
/***************************************************************************************************
 * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Process the input and output data in the Flash MLA kernels
*/

#pragma once

#include <limits>

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

template <typename T>
struct fill_nan_functor {
  T* blocked_k;
  const int* cache_seqlens;
  int max_seqlen_pad, h_kv, d;

  __host__ __device__
  fill_nan_functor(T* bk, const int* cs, int msp, int hkv, int dd) :
    blocked_k(bk), cache_seqlens(cs), max_seqlen_pad(msp), h_kv(hkv), d(dd) {}

  __host__ __device__
  void operator()(int idx) {
    auto NAN_VALUE = std::numeric_limits<T>::quiet_NaN();
    // blocked_k is laid out as [b, max_seqlen_pad, h_kv, d]
    int h_d_size = h_kv * d;
    int seq_h_d_size = max_seqlen_pad * h_d_size;

    int batch_idx = idx / seq_h_d_size;
    int pos_in_seq = (idx % seq_h_d_size) / h_d_size;

    // positions at or beyond this batch's cache length are poisoned with NaN
    if (pos_in_seq >= cache_seqlens[batch_idx]) {
      blocked_k[idx] = NAN_VALUE;
    }
  }
};

template <typename T>
void fill_nan(
  T* d_blocked_k,
  const int* d_cache_seqlens,
  int b, int max_seqlen_pad, int h_kv, int d
) {
  int total_elements = b * max_seqlen_pad * h_kv * d;

  thrust::for_each(
    thrust::device,
    thrust::counting_iterator<int>(0),
    thrust::counting_iterator<int>(total_elements),
    fill_nan_functor<T>(d_blocked_k, d_cache_seqlens, max_seqlen_pad, h_kv, d)
  );
}
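To make the index arithmetic concrete, a toy check with hypothetical sizes (test_fill_nan is not part of this commit; Element = float for simplicity):

#include <cmath>
#include <thrust/universal_vector.h>

bool test_fill_nan() {
  int b = 2, max_seqlen_pad = 4, h_kv = 1, d = 2;
  // blocked_k is [b, max_seqlen_pad, h_kv, d]: 8 floats per batch
  thrust::universal_vector<float> k(b * max_seqlen_pad * h_kv * d, 1.0f);
  thrust::universal_vector<int> seqlens(2);
  seqlens[0] = 3; seqlens[1] = 2;
  fill_nan(k.data().get(), seqlens.data().get(), b, max_seqlen_pad, h_kv, d);
  cudaDeviceSynchronize();
  // batch 0 keeps positions 0..2, batch 1 keeps positions 0..1
  return  std::isnan(float(k[0 * 8 + 3 * 2]))   // batch 0, position 3: poisoned
      &&  std::isnan(float(k[1 * 8 + 2 * 2]))   // batch 1, position 2: poisoned
      && !std::isnan(float(k[0 * 8 + 2 * 2]));  // batch 0, position 2: kept
}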
@@ -1,4 +1,4 @@
-// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/softmax.h
+// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h

 #pragma once

examples/68_hopper_flash_mla/transform.h (new file, 131 lines)
@@ -0,0 +1,131 @@
/***************************************************************************************************
 * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cutlass/cuda_host_adapter.hpp"

#include <cute/tensor.hpp>

using namespace cutlass;
using namespace cute;

template <class Element_>
struct TransposeKernel {
  using Element = Element_;

  // TODO: use more threads to copy tensor
  static constexpr int MaxThreadsPerBlock = 1;
  static constexpr int MinBlocksPerMultiprocessor = 1;
  using ArchTag = cutlass::arch::Sm90;

  static constexpr int AlignmentBytes = 16;

  struct SharedStorage {
    /* empty */
  };

  static constexpr int SharedStorageSize = sizeof(SharedStorage);

  // batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size
  using ProblemShape = cute::tuple<int, int, int, int, int>;

  struct Arguments {
    const Element *ptrS{nullptr};
    Element *ptrD{nullptr};
    ProblemShape problem_shape{};
  };

  struct Params {
    const Element *ptrS{nullptr};
    Element *ptrD{nullptr};
    ProblemShape problem_shape{};
  };

  static Params
  to_underlying_arguments(Arguments const& args, void* workspace) {
    return Params{
      args.ptrS,
      args.ptrD,
      args.problem_shape
    };
  }

  static Status
  can_implement(Arguments const& args) {
    return Status::kSuccess;
  }

  static size_t
  get_workspace_size(Arguments const& args) {
    return size_t(0);
  }

  static Status
  initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
                       CudaHostAdapter *cuda_adapter = nullptr) {
    return Status::kSuccess;
  }

  static dim3
  get_grid_shape(Params const& params) {
    auto [B, S, H, G, D] = params.problem_shape;
    return dim3(B*S, H, G);
  }

  static dim3
  get_block_shape() {
    return dim3(MaxThreadsPerBlock, 1, 1);
  }

  CUTLASS_DEVICE
  void
  operator()(Params params, [[maybe_unused]] char* smem_buf = nullptr) {
    auto [B, S, H, G, D] = params.problem_shape;

    // make_layout defaults to column-major, so build the shape reversed and
    // reverse it back to get a row-major [B, S, H, G, D] source layout
    auto src_layout_ = make_layout(make_shape(D,G,H,S,B));
    auto src_layout = make_layout(reverse(src_layout_.shape()), reverse(src_layout_.stride()));

    // row-major [B, S, G, H, D] destination layout (dims 2 and 3 swapped)
    auto dst_layout_ = make_layout(make_shape(D,H,G,S,B));
    auto dst_layout = make_layout(reverse(dst_layout_.shape()), reverse(dst_layout_.stride()));

    auto src_tensor = make_tensor(
      make_gmem_ptr(params.ptrS),
      group<0,2>(src_layout)   // fold (B, S) into one mode indexed by blockIdx.x
    );

    auto dst_tensor = make_tensor(
      make_gmem_ptr(params.ptrD),
      group<0,2>(dst_layout)
    );

    auto tS = src_tensor(blockIdx.x, blockIdx.y, blockIdx.z, _);
    // the destination modes are (B*S, G, H, D), so the group index blockIdx.z
    // selects mode 1 and the head index blockIdx.y selects mode 2
    auto tD = dst_tensor(blockIdx.x, blockIdx.z, blockIdx.y, _);

    copy(tS, tD);
  }
};
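Addressing the TODO above: since the head dimension D is contiguous in both layouts, the simplest parallelization strides the block's threads over D. A sketch of just the changed pieces (128 threads is an arbitrary choice, not from this commit; the layouts are built exactly as in operator() above):

  static constexpr int MaxThreadsPerBlock = 128;

  CUTLASS_DEVICE
  void operator()(Params params, [[maybe_unused]] char* smem_buf = nullptr) {
    auto [B, S, H, G, D] = params.problem_shape;

    auto src_layout_ = make_layout(make_shape(D,G,H,S,B));
    auto src_layout = make_layout(reverse(src_layout_.shape()), reverse(src_layout_.stride()));
    auto dst_layout_ = make_layout(make_shape(D,H,G,S,B));
    auto dst_layout = make_layout(reverse(dst_layout_.shape()), reverse(dst_layout_.stride()));

    auto tS = make_tensor(make_gmem_ptr(params.ptrS), group<0,2>(src_layout))
                (blockIdx.x, blockIdx.y, blockIdx.z, _);
    auto tD = make_tensor(make_gmem_ptr(params.ptrD), group<0,2>(dst_layout))
                (blockIdx.x, blockIdx.z, blockIdx.y, _);

    // each thread copies a strided subset of the D contiguous elements
    for (int i = threadIdx.x; i < D; i += blockDim.x) {
      tD(i) = tS(i);
    }
  }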
@@ -1,4 +1,4 @@
-// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/utils.h
+// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h

 #pragma once
