Bugfix - pass 'max_num_tokens_padded' into 'moe_lora_align_block_size' (#27311)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Author: gnovack
Date: 2025-10-22 05:23:57 -07:00
Committed by: GitHub
Parent: 1a0f4defb7
Commit: 8e4ca4d14e

8 changed files with 17 additions and 8 deletions

@@ -124,18 +124,14 @@ __global__ void moe_lora_align_sum_kernel(
 void moe_lora_align_block_size(torch::Tensor topk_ids,
                                torch::Tensor token_lora_mapping,
                                int64_t num_experts, int64_t block_size,
-                               int64_t max_loras,
+                               int64_t max_loras, int64_t max_num_tokens_padded,
+                               int64_t max_num_m_blocks,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
                                torch::Tensor num_tokens_post_pad) {
   const int topk_num = topk_ids.size(1);
-  int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
   TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
-  max_num_tokens_padded = round_to_next_multiple_of(
-      max_num_tokens_padded, static_cast<int>(block_size));
-  int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);
   int device_max_shared_mem;
   auto dev = topk_ids.get_device();
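
With the padded sizes now passed in, the launcher no longer derives them itself, so the caller has to compute max_num_tokens_padded and max_num_m_blocks before invoking the op. A minimal Python sketch of that computation, mirroring the arithmetic removed above (the helper name compute_moe_lora_align_sizes is illustrative, not part of this change):

import torch

def compute_moe_lora_align_sizes(
    topk_ids: torch.Tensor, num_experts: int, block_size: int
) -> tuple[int, int]:
    """Mirror of the size math removed from the C++ launcher above."""
    assert block_size > 0, "block_size should be greater than 0."
    # Worst case: each expert contributes up to (block_size - 1) padding slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    # round_to_next_multiple_of: round up to a multiple of block_size.
    max_num_tokens_padded = (
        (max_num_tokens_padded + block_size - 1) // block_size * block_size
    )
    # div_ceil: number of M blocks covering the padded token count.
    max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size
    return max_num_tokens_padded, max_num_m_blocks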

@@ -23,7 +23,8 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
 void moe_lora_align_block_size(torch::Tensor topk_ids,
                                torch::Tensor token_lora_mapping,
                                int64_t num_experts, int64_t block_size,
-                               int64_t max_loras,
+                               int64_t max_loras, int64_t max_num_tokens_padded,
+                               int64_t max_num_m_blocks,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
                                torch::Tensor num_tokens_post_pad);

@@ -40,6 +40,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " Tensor token_lora_mapping,"
       " int num_experts,"
       " int block_size, int max_loras, "
+      " int max_num_tokens_padded, "
+      " int max_num_m_blocks, "
       " Tensor !sorted_token_ids,"
       " Tensor !experts_ids,"
       " Tensor !num_tokens_post_pad) -> () ");

@@ -142,6 +142,8 @@ def use_fused_moe_lora_kernel(
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,

@@ -36,7 +36,7 @@ def test_gptoss20b_lora(gptoss20b_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
         enable_lora=True,
-        max_loras=1,
+        max_loras=4,
         trust_remote_code=True,
     )

@@ -68,6 +68,8 @@ def test_moe_lora_align_block_size(
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,

@@ -1801,6 +1801,8 @@ def moe_lora_align_block_size(
     num_experts: int,
     block_size: int,
     max_loras: int,
+    max_num_tokens_padded: int,
+    max_num_m_blocks: int,
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
@@ -1811,6 +1813,8 @@ def moe_lora_align_block_size(
         num_experts,
         block_size,
         max_loras,
+        max_num_tokens_padded,
+        max_num_m_blocks,
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
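
A rough end-to-end usage sketch of the updated wrapper follows. The output buffer shapes and dtypes are assumptions for illustration (per-LoRA rows sized by the two new arguments); the real callers in this change own the actual allocation:

import torch
from vllm import _custom_ops as ops

num_tokens, topk, num_experts, block_size, max_loras = 8, 2, 4, 16, 2
topk_ids = torch.randint(
    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
)
token_lora_mapping = torch.zeros(num_tokens, dtype=torch.int32, device="cuda")

# Same arithmetic as the C++ code this change removed.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = (
    (max_num_tokens_padded + block_size - 1) // block_size * block_size
)
max_num_m_blocks = max_num_tokens_padded // block_size

# Assumed layout: one row of sorted ids / expert ids per LoRA slot.
sorted_token_ids = torch.empty(
    (max_loras, max_num_tokens_padded), dtype=torch.int32, device="cuda"
)
expert_ids = torch.empty(
    (max_loras, max_num_m_blocks), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.empty(max_loras, dtype=torch.int32, device="cuda")

ops.moe_lora_align_block_size(
    topk_ids,
    token_lora_mapping,
    num_experts,
    block_size,
    max_loras,
    max_num_tokens_padded,  # passed in rather than recomputed in C++
    max_num_m_blocks,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_pad,
)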

@@ -341,6 +341,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             num_experts,
             block_size,
             max_loras,
+            max_num_tokens_padded,
+            max_num_m_blocks,
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,