Implement custom kernel for LLaMA rotary embedding (#14)

This commit is contained in:
Woosuk Kwon
2023-03-30 11:04:21 -07:00
committed by GitHub
parent 80a2f812f1
commit 88c0268a18
10 changed files with 318 additions and 69 deletions

16
csrc/pos_encoding.cpp Normal file
View File

@ -0,0 +1,16 @@
#include <torch/extension.h>
// Forward declaration of the rotary-embedding kernel entry point; the
// implementation lives in a separate CUDA/C++ source in this extension.
// Applies GPT-NeoX style rotary positional embedding to `query` and `key`,
// writing results into `out_query` and `out_key` (per the binding docstring
// below).
//
// NOTE(review): tensor shapes and dtypes are not visible here — presumably
// `positions` holds per-token position indices and `cos_sin_cache` holds a
// precomputed table of cos/sin values; confirm against the kernel source.
void rotary_embedding_neox(
torch::Tensor& out_query,      // output tensor for the rotated query
torch::Tensor& out_key,        // output tensor for the rotated key
torch::Tensor& positions,      // token positions (semantics: see NOTE above)
torch::Tensor& query,          // input query
torch::Tensor& key,            // input key
torch::Tensor& cos_sin_cache); // precomputed cos/sin cache (see NOTE above)
// Python bindings for this extension module. The module name is supplied by
// the Torch build system via TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Expose the rotary-embedding kernel to Python under the same name.
  m.def("rotary_embedding_neox", &rotary_embedding_neox,
        "Apply GPT-NeoX style rotary embedding to query and key");
}