[Feature] vLLM ARM Enablement for AARCH64 CPUs (#9228)

Signed-off-by: Sanket Kale <sanketk.kale@fujitsu.com>
Co-authored-by: Sanket Kale <sanketk.kale@fujitsu.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Sanket Kale
2024-11-26 08:02:39 +05:30
committed by GitHub
parent 45ac4ff270
commit a6760f6456
9 changed files with 678 additions and 16 deletions

View File

@ -51,6 +51,10 @@ struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16;
};
#else
#ifdef __aarch64__
#ifndef ARM_BF16_SUPPORT
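// Intentionally empty: no KernelVecType<c10::BFloat16> specialization is
// defined on aarch64 when ARM_BF16_SUPPORT is not defined.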
#else
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
@ -60,6 +64,18 @@ struct KernelVecType<c10::BFloat16> {
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
#else
// Explicit specialization of KernelVecType for c10::BFloat16, compiled only
// when __aarch64__ is NOT defined (this is the #else arm of the
// `#ifdef __aarch64__` check above) and the enclosing outer condition is
// also false.
// NOTE(review): the outer #if condition is not visible in this hunk;
// presumably it selects a native BF16 path -- confirm in the full file.
//
// The struct has no data members; it only names the vec_op vector types
// used for this scalar type.
// NOTE(review): judging by the alias names, the q_/k_/v_ prefixes refer to
// query/key/value and "qk_acc" to a query-key accumulator -- confirm
// against the kernel code that consumes these aliases.
template <>
struct KernelVecType<c10::BFloat16> {
// Load types are BF16 vectors (8 lanes for q, 16 lanes for k and v).
using q_load_vec_type = vec_op::BF16Vec8;
// The non-load types below are all FP32Vec16.
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
#endif
template <typename T>
@ -779,4 +795,4 @@ void paged_attention_v2(
CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
});
}
}