[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -32,8 +32,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// PagedAttention V2.
|
||||
ops.def(
|
||||
"paged_attention_v2("
|
||||
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
|
||||
" Tensor tmp_out, Tensor query, Tensor key_cache,"
|
||||
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
|
||||
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
@@ -122,8 +122,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
// Copy the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
|
||||
"block_mapping) -> ()");
|
||||
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
|
||||
"Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
|
||||
Reference in New Issue
Block a user