Change the name to vLLM (#150)

This commit is contained in:
Woosuk Kwon
2023-06-17 03:07:40 -07:00
committed by GitHub
parent e5464ee484
commit 0b98ba15c7
90 changed files with 342 additions and 339 deletions

View File

@@ -46,7 +46,7 @@ void swap_blocks(
}
}
-namespace cacheflow {
+namespace vllm {
// Grid: (num_layers, num_pairs)
template<typename scalar_t>
@@ -77,7 +77,7 @@ __global__ void copy_blocks_kernel(
}
}
-} // namespace cacheflow
+} // namespace vllm
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
@@ -129,7 +129,7 @@ void copy_blocks(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
-cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int>(),
@@ -137,7 +137,7 @@ void copy_blocks(
}));
}
-namespace cacheflow {
+namespace vllm {
template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
@@ -181,7 +181,7 @@ __global__ void reshape_and_cache_kernel(
}
}
-} // namespace cacheflow
+} // namespace vllm
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
@@ -208,7 +208,7 @@ void reshape_and_cache(
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
-cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
@@ -223,7 +223,7 @@ void reshape_and_cache(
});
}
-namespace cacheflow {
+namespace vllm {
// Grid: (num_blocks, block_size).
template<typename scalar_t>
@@ -343,7 +343,7 @@ __global__ void gather_cached_kv_kernel_optimized(
}
}
-} // namespace cacheflow
+} // namespace vllm
void gather_cached_kv(
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
@@ -370,7 +370,7 @@ void gather_cached_kv(
key.scalar_type(),
"gather_cached_kv_kernel_optimized",
[&] {
-cacheflow::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
+vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),