[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)
Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
@ -5,11 +5,11 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// vllm_kernel_override_batch_invariant(); returns true
|
||||
// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
|
||||
inline bool vllm_kernel_override_batch_invariant() {
|
||||
// vllm_is_batch_invariant(); returns true
|
||||
// if env VLLM_BATCH_INVARIANT=1
|
||||
inline bool vllm_is_batch_invariant() {
|
||||
static bool cached = []() {
|
||||
std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
|
||||
std::string env_key = "VLLM_BATCH_INVARIANT";
|
||||
const char* val = std::getenv(env_key.c_str());
|
||||
return (val && std::atoi(val) != 0) ? 1 : 0;
|
||||
}();
|
||||
|
||||
@ -426,7 +426,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
wt_ptr % req_alignment_bytes == 0;
|
||||
bool offsets_are_multiple_of_vector_width =
|
||||
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
|
||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
|
||||
!batch_invariant_launch) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
@ -474,7 +474,7 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
|
||||
LAUNCH_FUSED_POLY_NORM(8);
|
||||
} else {
|
||||
|
||||
@ -254,7 +254,7 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
|
||||
!batch_invariant_launch) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
|
||||
Reference in New Issue
Block a user