[Perf] Optimize reshape_and_cache_flash CUDA Kernel (#22036)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-08-01 19:18:51 -04:00
committed by GitHub
parent 88faa466d7
commit eefbf4a68b
2 changed files with 225 additions and 23 deletions

View File

@ -5,6 +5,7 @@
#include "cuda_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
@ -261,14 +262,26 @@ __global__ void reshape_and_cache_kernel(
}
}
// Used by vectorization_utils to copy/convert one element
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
struct CopyWithScaleOp {
float scale;
__device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst = static_cast<OutT>(src);
} else {
dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
}
}
};
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
// head_size]
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
// head_size]
cache_t* __restrict__ key_cache, // NHD or HND, shape see comments below
cache_t* __restrict__ value_cache, // same above
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
@ -282,25 +295,58 @@ __global__ void reshape_and_cache_flash_kernel(
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_key_value_idx = block_idx * block_stride +
block_offset * page_stride +
head_idx * head_stride + head_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_value_idx] = tgt_key;
value_cache[tgt_key_value_idx] = tgt_value;
} else {
key_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
value_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
const int n_elems = num_heads * head_size;
// pointers to the beginning of the source row for this token.
const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
const scalar_t* __restrict__ value_src = value + token_idx * value_stride;
// find the start position inside the kv-cache for this token.
cache_t* __restrict__ key_dst =
key_cache + block_idx * block_stride + block_offset * page_stride;
cache_t* __restrict__ value_dst =
value_cache + block_idx * block_stride + block_offset * page_stride;
// this is true for the NHD layout where `head_stride == head_size`
const bool is_contiguous_heads = (head_stride == head_size);
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
if (is_contiguous_heads) {
// NHD layout
// kv cache: [num_blocks, block_size, num_heads, head_size]
vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
blockDim.x, k_op);
vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
threadIdx.x, blockDim.x, v_op);
} else {
// HND layout: heads are strided, but each head_size segment is contiguous
// kv cache: [num_blocks, num_heads, block_size, head_size]
const int lane = threadIdx.x & 31; // 0..31 within warp
const int warp_id = threadIdx.x >> 5; // warp index within block
const int warps_per_block = blockDim.x >> 5;
for (int head = warp_id; head < num_heads; head += warps_per_block) {
const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
cache_t* __restrict__ k_dst_h =
key_dst + static_cast<int64_t>(head) * head_stride;
cache_t* __restrict__ v_dst_h =
value_dst + static_cast<int64_t>(head) * head_stride;
// within each head, let the 32 threads of the warp perform the vector
// copy
vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
k_op);
vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, lane, 32,
v_op);
}
}
}