Support beam search & parallel generation (#7)

This commit is contained in:
Woosuk Kwon
2023-03-10 09:58:21 -08:00
committed by GitHub
parent 04e5acc08e
commit 1a7eb7da61
16 changed files with 660 additions and 161 deletions

View File

@ -1,9 +1,17 @@
#include <torch/extension.h>
#include <map>
#include <vector>
// Swaps cache blocks in (or out) from `src` to `dst`. Registered with
// pybind11 under the Python-visible name "swap_blocks".
//
// `block_mapping` holds one int64 value per int64 key.
// NOTE(review): presumably source block number -> destination block number;
// confirm against the definition, which is not visible here.
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
// Copies cache blocks from `src` to `dst`. Registered with pybind11 under the
// Python-visible name "copy_blocks".
//
// `block_mapping` holds a list of int64 values per int64 key, so a single key
// can be associated with several values.
// NOTE(review): presumably source block number -> destination block numbers;
// confirm against the definition, which is not visible here.
void copy_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
void reshape_and_cache(
torch::Tensor& key,
@ -14,7 +22,11 @@ void reshape_and_cache(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"copy_cache_blocks",
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
m.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
m.def(