Bugfix - pass 'max_num_tokens_padded' into 'moe_lora_align_block_size' (#27311)
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -124,18 +124,14 @@ __global__ void moe_lora_align_sum_kernel(
 void moe_lora_align_block_size(torch::Tensor topk_ids,
                                torch::Tensor token_lora_mapping,
                                int64_t num_experts, int64_t block_size,
-                               int64_t max_loras,
+                               int64_t max_loras, int64_t max_num_tokens_padded,
+                               int64_t max_num_m_blocks,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
                                torch::Tensor num_tokens_post_pad) {
   const int topk_num = topk_ids.size(1);
 
-  int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
-
   TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
-  max_num_tokens_padded = round_to_next_multiple_of(
-      max_num_tokens_padded, static_cast<int>(block_size));
-  int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);
 
   int device_max_shared_mem;
   auto dev = topk_ids.get_device();
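With the kernel no longer deriving these sizes, the caller must compute them before invoking the op. A minimal Python sketch of that arithmetic, mirroring the C++ lines removed above (the helper name compute_padding_sizes is illustrative, not vLLM's actual code):

import torch

def compute_padding_sizes(
    topk_ids: torch.Tensor, num_experts: int, block_size: int
) -> tuple[int, int]:
    assert block_size > 0, "block_size should be greater than 0."
    # Worst case: every expert's token count misses a full block by
    # block_size - 1 entries.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    # Mirrors round_to_next_multiple_of on the C++ side: round up to a
    # multiple of block_size.
    max_num_tokens_padded = (
        (max_num_tokens_padded + block_size - 1) // block_size * block_size
    )
    # Mirrors div_ceil on the C++ side: number of block_size-wide M blocks.
    max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size
    return max_num_tokens_padded, max_num_m_blocks

At this point max_num_tokens_padded is already block-aligned, so the final ceiling division reduces to an exact division; keeping it mirrors the removed div_ceil call.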
@@ -23,7 +23,8 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
 void moe_lora_align_block_size(torch::Tensor topk_ids,
                                torch::Tensor token_lora_mapping,
                                int64_t num_experts, int64_t block_size,
-                               int64_t max_loras,
+                               int64_t max_loras, int64_t max_num_tokens_padded,
+                               int64_t max_num_m_blocks,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
                                torch::Tensor num_tokens_post_pad);
@@ -40,6 +40,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " Tensor token_lora_mapping,"
       " int num_experts,"
       " int block_size, int max_loras, "
+      " int max_num_tokens_padded, "
+      " int max_num_m_blocks, "
       " Tensor !sorted_token_ids,"
       " Tensor !experts_ids,"
       " Tensor !num_tokens_post_pad) -> () ");
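The two new schema arguments exist so the caller can shape the output buffers up front. A hedged sketch of that allocation (the per-LoRA leading dimension, the int32 dtype, and the device choice are assumptions here, not something the diff confirms):

# Assumed layout: one row of alignment metadata per active LoRA adapter.
sorted_token_ids = torch.empty(
    (max_loras, max_num_tokens_padded), dtype=torch.int32, device=topk_ids.device
)
expert_ids = torch.empty(
    (max_loras, max_num_m_blocks), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty(
    (max_loras,), dtype=torch.int32, device=topk_ids.device
)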
@@ -142,6 +142,8 @@ def use_fused_moe_lora_kernel(
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
@@ -36,7 +36,7 @@ def test_gptoss20b_lora(gptoss20b_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
         enable_lora=True,
-        max_loras=1,
+        max_loras=4,
         trust_remote_code=True,
     )
 
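Raising max_loras from 1 to 4 lets the test exercise multiple concurrently active adapters, which is the case the new kernel arguments affect. A hedged sketch of how a test might drive several adapter slots (prompt, sampling settings, and adapter names are illustrative; the repeated use of one weights path simply fills distinct slots):

from vllm import SamplingParams
from vllm.lora.request import LoRARequest

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = [
    llm.generate(
        ["What is the capital of France?"],
        sampling_params,
        # Distinct lora_int_ids occupy distinct adapter slots.
        lora_request=LoRARequest(f"adapter-{i}", i, gptoss20b_lora_files),
    )
    for i in range(1, 5)  # ids 1..4, matching max_loras=4
]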
@@ -68,6 +68,8 @@ def test_moe_lora_align_block_size(
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,
@@ -1801,6 +1801,8 @@ def moe_lora_align_block_size(
     num_experts: int,
     block_size: int,
     max_loras: int,
+    max_num_tokens_padded: int,
+    max_num_m_blocks: int,
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
@@ -1811,6 +1813,8 @@ def moe_lora_align_block_size(
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
@@ -341,6 +341,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_ids,
         expert_ids,
         num_tokens_post_pad,
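Taken together, a caller-side flow under the new signature might look like the following sketch (compute_padding_sizes is the illustrative helper from above; ops stands in for the module exposing the bound op, and the buffers are the ones sized earlier):

max_num_tokens_padded, max_num_m_blocks = compute_padding_sizes(
    topk_ids, num_experts, block_size
)
# The sizes are now passed in rather than recomputed inside the kernel launch.
ops.moe_lora_align_block_size(
    topk_ids,
    token_lora_mapping,
    num_experts,
    block_size,
    max_loras,
    max_num_tokens_padded,
    max_num_m_blocks,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_pad,
)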