Add miscellaneous updates (#8)

This commit is contained in:
Woosuk Kwon
2023-03-13 13:48:38 -07:00
committed by GitHub
parent e9d3f2ff77
commit cfae35b861
7 changed files with 44 additions and 22 deletions

View File

@ -1,5 +1,4 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <algorithm>
@ -73,6 +72,8 @@ void copy_blocks(
}
}
namespace cacheflow {
template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
@ -112,6 +113,8 @@ __global__ void reshape_and_cache_kernel(
}
}
} // namespace cacheflow
void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
@ -131,7 +134,7 @@ void reshape_and_cache(
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),