[Kernel][Misc] register ops to prevent graph breaks (#6917)

Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
bnellnm
2024-09-11 15:52:19 -04:00
committed by GitHub
parent 7015417fd4
commit 73202dbe77
22 changed files with 528 additions and 102 deletions

View File

@ -32,8 +32,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
@ -122,8 +122,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()");
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
// Reshape the key and value tensors and cache them.