Make key optional for rotary embedding (#17566)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
Yong Hoon Shin
2025-05-07 00:11:46 -07:00
committed by GitHub
parent 324a3119b0
commit 98c89e16ff
10 changed files with 221 additions and 151 deletions

View File

@ -117,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);