Add miscellaneous updates (#8)
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <algorithm>
|
||||
@ -73,6 +72,8 @@ void copy_blocks(
|
||||
}
|
||||
}
|
||||
|
||||
namespace cacheflow {
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
@ -112,6 +113,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
@ -131,7 +134,7 @@ void reshape_and_cache(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_kernel",
|
||||
[&] {
|
||||
reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
|
||||
Reference in New Issue
Block a user