[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
committed by
GitHub
parent
6e650f56a1
commit
e97f802b2d
@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
|
||||
int max_ctx_blocks, float k_scale, float v_scale) {
|
||||
int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) {
|
||||
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
|
||||
const _B8x8 Vlocalb8 = v_ptrh8be[d];
|
||||
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, *v_scale_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
#pragma unroll
|
||||
for (int d = 0; d < KHELOOP; d++) {
|
||||
Klocal[d] =
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], *k_scale_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
|
||||
int max_ctx_blocks, float k_scale, float v_scale) {
|
||||
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
|
||||
@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
|
||||
k_scale, v_scale);
|
||||
k_scale_ptr, v_scale_ptr);
|
||||
|
||||
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
|
||||
@ -929,7 +929,7 @@ void paged_attention_custom_launcher(
|
||||
torch::Tensor& value_cache, const int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& context_lens,
|
||||
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
float k_scale, float v_scale) {
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -953,6 +953,8 @@ void paged_attention_custom_launcher(
|
||||
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
|
||||
const int max_num_partitions =
|
||||
@ -1087,7 +1089,8 @@ void paged_attention(
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_context_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
const int head_size = query.size(2);
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
|
||||
@ -10,5 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
torch::Tensor& context_lens, int64_t block_size,
|
||||
int64_t max_context_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale);
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale);
|
||||
|
||||
@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
" int max_context_len,"
|
||||
" Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user