diff --git a/benchmarks/kernels/benchmark_moe_align_block_size.py b/benchmarks/kernels/benchmark_moe_align_block_size.py
new file mode 100644
index 0000000000..024a5dcfc8
--- /dev/null
+++ b/benchmarks/kernels/benchmark_moe_align_block_size.py
@@ -0,0 +1,159 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+import itertools
+
+import torch
+import triton
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
+    moe_align_block_size_triton,
+)
+
+
+def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
+    return torch.stack(
+        [
+            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
+            for _ in range(num_tokens)
+        ]
+    )
+
+
+def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
+    """
+    Verifies that the vLLM CUDA kernel and the Triton kernel produce matching results.
+    """
+    topk_ids = get_topk_ids(num_tokens, num_experts, topk)
+
+    # 1. malloc space for triton and vllm
+    # malloc enough space (max_num_tokens_padded) for the sorted ids
+    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    sorted_ids_triton = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
+    )
+    sorted_ids_triton.fill_(topk_ids.numel())  # fill with sentinel value
+    expert_ids_triton = torch.zeros(
+        (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
+    )
+    num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
+
+    sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
+    sorted_ids_vllm.fill_(topk_ids.numel())
+    expert_ids_vllm = torch.zeros_like(expert_ids_triton)
+    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
+
+    # 2. run implementations
+    moe_align_block_size_triton(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids_triton,
+        expert_ids_triton,
+        num_tokens_post_pad_triton,
+    )
+
+    ops.moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids_vllm,
+        expert_ids_vllm,
+        num_tokens_post_pad_vllm,
+    )
+    print(f"✅ VLLM implementation works with {num_experts} experts!")
+
+    # 3. compare results
+    if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
+        num_tokens_post_pad_triton, num_tokens_post_pad_vllm
+    ):
+        print("✅ Triton and VLLM implementations match.")
+    else:
+        print("❌ Triton and VLLM implementations DO NOT match.")
+        print("Triton expert_ids:", expert_ids_triton)
+        print("VLLM expert_ids:", expert_ids_vllm)
+        print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
+        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
+
+
+# test configurations
+num_tokens_range = [1, 16, 256, 4096]
+num_experts_range = [16, 64, 224, 256, 280, 512]
+topk_range = [1, 2, 8]
+configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_tokens", "num_experts", "topk"],
+        x_vals=configs,
+        line_arg="provider",
+        line_vals=["vllm", "triton"],
+        line_names=["VLLM", "Triton"],
+        plot_name="moe-align-block-size-performance",
+        args={},
+    )
+)
+def benchmark(num_tokens, num_experts, topk, provider):
+    """Benchmark the vLLM CUDA op against the Triton implementation."""
+    block_size = 256
+    topk_ids = get_topk_ids(num_tokens, num_experts, topk)
+
+    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
+    sorted_ids.fill_(topk_ids.numel())
+    max_num_m_blocks = max_num_tokens_padded // block_size
+    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
+    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "vllm":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: ops.moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids.clone(),
+                expert_ids.clone(),
+                num_tokens_post_pad.clone(),
+            ),
+            quantiles=quantiles,
+        )
+    elif provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: moe_align_block_size_triton(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids.clone(),
+                expert_ids.clone(),
+                num_tokens_post_pad.clone(),
+            ),
+            quantiles=quantiles,
+        )
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--num_experts",
+        type=int,
+        default=64,
+        choices=[8, 16, 32, 64, 128, 256],
+    )
+    parser.add_argument(
+        "--topk",
+        type=int,
+        default=8,
+        choices=[2, 4, 8],
+        help="Top-k value for correctness check.",
+    )
+    args = parser.parse_args()
+
+    print("Running correctness check...")
+    check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
+    benchmark.run(print_data=True, show_plots=True)
diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index 6b6a9d04a6..9335e2333b 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -13,232 +13,45 @@
 namespace vllm {
 namespace moe {
 
-namespace {
-__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
-                                         int32_t col) {
-  // don't worry about overflow because num_experts is relatively small
-  return row * total_col + col;
-}
-}  // namespace
-
-template <typename scalar_t, typename token_cnts_t>
-__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
-                                            int32_t* sorted_token_ids,
-                                            int32_t* expert_ids,
-                                            int32_t* total_tokens_post_pad,
-                                            int32_t num_experts,
-                                            int32_t block_size, size_t numel) {
-  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
-  const size_t start_idx = threadIdx.x *
tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1) - token_cnts_t* tokens_cnts = - (token_cnts_t*)(shared_mem + num_experts + - 1); // 2d tensor with shape (blockDim.x + 1, num_experts) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are - * assigned to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - if (threadIdx.x < num_experts) { - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; - } - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; - } - *total_tokens_post_pad = static_cast(cumsum[num_experts]); - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding - * blocks and stores the corresponding expert_id for each block. - */ - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; - } - } - - /** - * Each thread processes a token shard, calculating the index of each token - * after sorting by expert number. Given the example topk_ids = - * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, - * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a - * padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. - */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} - -// TODO(simon): this is temporarily adapted from -// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7 -// we did this to unblock Deepseek V3 but there should be a better -// implementation to manage shared memory. 
template -__global__ void moe_align_block_size_global_mem_kernel( - scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, - int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum) { + extern __shared__ int32_t shared_counts[]; - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are - * assigned to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - if (threadIdx.x < num_experts) { - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; - } - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding - * blocks and stores the corresponding expert_id for each block. - */ - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; - } - } - - /** - * Each thread processes a token shard, calculating the index of each token - * after sorting by expert number. Given the example topk_ids = - * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, - * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a - * padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. 
- */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} - -// taken from -// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 -template -__global__ void sgl_moe_align_block_size_kernel( - scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, - int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* cumsum) { - __shared__ int32_t shared_counts[32][8]; - - const int warp_id = threadIdx.x / 32; - const int experts_per_warp = 8; + const int warp_id = threadIdx.x / WARP_SIZE; const int my_expert_start = warp_id * experts_per_warp; - // Initialize shared_counts for this warp's experts for (int i = 0; i < experts_per_warp; ++i) { - if (my_expert_start + i < num_experts) { - shared_counts[warp_id][i] = 0; + if (my_expert_start + i < padded_num_experts) { + shared_counts[warp_id * experts_per_warp + i] = 0; } } __syncthreads(); - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + for (size_t i = tid; i < numel; i += stride) { int expert_id = topk_ids[i]; int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; - atomicAdd(&shared_counts[warp_idx][expert_offset], 1); + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); } __syncthreads(); - // Single thread computes cumulative sum and total tokens if (threadIdx.x == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { int expert_count = 0; int warp_idx = (i - 1) / experts_per_warp; int expert_offset = (i - 1) % experts_per_warp; - expert_count = shared_counts[warp_idx][expert_offset]; + expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; @@ -248,7 +61,6 @@ __global__ void sgl_moe_align_block_size_kernel( __syncthreads(); - // Assign expert IDs to blocks if (threadIdx.x < num_experts) { for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { @@ -257,13 +69,11 @@ __global__ void sgl_moe_align_block_size_kernel( } } -// taken from -// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 template -__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, - int32_t* sorted_token_ids, - int32_t* cumsum_buffer, - size_t numel) { +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + size_t numel) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; @@ -290,132 +100,138 @@ __global__ void moe_sum_kernel( } } +template +__global__ void moe_align_block_size_small_batch_expert_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* 
tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]]; + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + tokens_cnts[threadIdx.x] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[i * num_experts + threadIdx.x] += + tokens_cnts[(i - 1) * num_experts + threadIdx.x]; + } + } + + __syncthreads(); + + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * + block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = + tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[threadIdx.x * num_experts + expert_id]; + } +} + } // namespace moe } // namespace vllm +// taken from +// https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int device_max_shared_mem; - auto dev = topk_ids.get_device(); - cudaDeviceGetAttribute(&device_max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem_i32 = - ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); - const int32_t shared_mem_i16 = - ((num_thread + 1) * num_experts) * sizeof(uint16_t) + - (num_experts + 1) * sizeof(int32_t); - - bool use_global_memory = false; - bool use_i16 = false; // Use uint16_t for shared memory token counts - if (shared_mem_i32 < device_max_shared_mem) { - // Do nothing in this case. 
We're all set to use int32_t token counts - } else if (shared_mem_i16 < device_max_shared_mem && - topk_ids.numel() <= 65535) { - // when nelements of topk_ids is smaller than 65535 (max value of uint16), - // element value of token_cnts would also smaller than 65535, - // so we can use uint16 as dtype of token_cnts - use_i16 = true; - } else { - use_global_memory = true; - } - - if (use_global_memory) { - VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - - auto options_int = torch::TensorOptions() - .dtype(torch::kInt) - .device(topk_ids.device()); - torch::Tensor token_cnts_buffer = - torch::empty({(num_experts + 1) * num_experts}, options_int); - torch::Tensor cumsum_buffer = - torch::empty({num_experts + 1}, options_int); - - auto kernel = - vllm::moe::moe_align_block_size_global_mem_kernel; - kernel<<<1, num_thread, 0, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), token_cnts_buffer.data_ptr(), - cumsum_buffer.data_ptr()); - }); - } else if (use_i16) { - VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // set dynamic shared mem - auto kernel = - vllm::moe::moe_align_block_size_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem_i16)); - kernel<<<1, num_thread, shared_mem_i16, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); - }); - } else { - VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - auto kernel = - vllm::moe::moe_align_block_size_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem_i32)); - kernel<<<1, num_thread, shared_mem_i32, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); - }); - } -} - -void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(num_experts == 256, - "sgl_moe_align_block_size kernel only supports deepseek v3."); + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int experts_per_warp = WARP_SIZE; + int threads = 1024; + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( - topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); torch::Tensor cumsum_buffer = torch::zeros({num_experts + 1}, options_int); + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); - auto align_kernel = - vllm::moe::sgl_moe_align_block_size_kernel; - align_kernel<<<1, 1024, 0, stream>>>( - topk_ids.data_ptr(), 
sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), cumsum_buffer.data_ptr()); + if (small_batch_expert_mode) { + const int32_t threads = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_size = + ((threads + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); - const int block_threads = 256; - const int num_blocks = - (topk_ids.numel() + block_threads - 1) / block_threads; - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel; - sort_kernel<<>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel()); + auto small_batch_expert_kernel = + vllm::moe::moe_align_block_size_small_batch_expert_kernel< + scalar_t>; + small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel()); + } else { + auto align_kernel = vllm::moe::moe_align_block_size_kernel; + + size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); + size_t shared_mem_size = + num_warps * experts_per_warp * sizeof(int32_t); + + align_kernel<<<1, threads, shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, + padded_num_experts, experts_per_warp, block_size, + topk_ids.numel(), cumsum_buffer.data_ptr()); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + auto sort_kernel = + vllm::moe::count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), topk_ids.numel()); + } }); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index c4faef7310..661730c968 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -12,12 +12,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); - -void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d6ef4940b6..97df311d04 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -22,15 +22,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor! num_tokens_post_pad) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - // temporarily adapted from - // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a - m.def( - "sgl_moe_align_block_size(Tensor topk_ids, int num_experts," - " int block_size, Tensor! sorted_token_ids," - " Tensor! experts_ids," - " Tensor! num_tokens_post_pad) -> ()"); - m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); - #ifndef USE_ROCM m.def( "moe_wna16_gemm(Tensor input, Tensor! 
output, Tensor b_qweight, " diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py new file mode 100644 index 0000000000..e980422a7b --- /dev/null +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size_triton) + + +@pytest.mark.parametrize( + "block_size,num_tokens,topk,num_experts", + list( + itertools.product( + [32, 64, 128, 256], # block_size + [ + 1, + 3, + 7, + 16, + 256, + 2256, + 4096, + ], # num_tokens + [1, 4, 16, 64], # topk + [64, 160, 256, 257, 260, 264], # num_experts + )), +) +def test_moe_align_block_size_compare_implementations(block_size, num_tokens, + topk, num_experts): + topk_ids = torch.stack([ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(num_tokens) + ]) + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + + sorted_ids_cuda = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids_cuda.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids_cuda = torch.zeros((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad_cuda = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + + sorted_ids_triton = torch.empty_like(sorted_ids_cuda) + sorted_ids_triton.fill_(topk_ids.numel()) + expert_ids_triton = torch.zeros_like(expert_ids_cuda) + num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) + + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids_cuda, + expert_ids_cuda, + num_tokens_post_pad_cuda, + ) + + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids_triton, + expert_ids_triton, + num_tokens_post_pad_triton, + ) + + assert torch.allclose(expert_ids_cuda, expert_ids_triton), ( + f"Expert IDs mismatch for block_size={block_size}, " + f"num_tokens={num_tokens}, topk={topk}\n" + f"CUDA expert_ids: {expert_ids_cuda}\n" + f"Triton expert_ids: {expert_ids_triton}") + + assert torch.allclose( + num_tokens_post_pad_cuda, num_tokens_post_pad_triton), ( + f"Num tokens post pad mismatch for block_size={block_size}, " + f"num_tokens={num_tokens}, topk={topk}\n" + f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n" + f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}") + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ff992c33b3..b16fef8714 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1524,15 +1524,6 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, num_tokens_post_pad) -def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, - block_size: int, sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor) -> None: - torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts, - block_size, sorted_token_ids, - experts_ids, num_tokens_post_pad) - - def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, b_qweight: torch.Tensor, b_scales: torch.Tensor, b_qzeros: Optional[torch.Tensor], diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py 
b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
index 9d990959e0..f9451ca2fd 100644
--- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -4,7 +4,6 @@ from typing import Optional
 
 import torch
 
-import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.triton_utils import tl, triton
 from vllm.utils import round_up
@@ -99,6 +98,7 @@ def moe_align_block_size_stage4(
 
 # Triton implementation based on:
 # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
+# TODO(wentao): Deprecate this function in the future.
 def moe_align_block_size_triton(
     topk_ids: torch.Tensor,
     num_experts: int,
@@ -220,29 +220,9 @@ def moe_align_block_size(
     num_tokens_post_pad = torch.empty((1),
                                       dtype=torch.int32,
                                       device=topk_ids.device)
-    if num_experts >= 224:
-        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
-            moe_align_block_size_triton(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids,
-                expert_ids,
-                num_tokens_post_pad,
-            )
-        else:
-            # Currently requires num_experts=256
-            ops.sgl_moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids,
-                expert_ids,
-                num_tokens_post_pad,
-            )
-    else:
-        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                                 expert_ids, num_tokens_post_pad)
+
+    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
+                             expert_ids, num_tokens_post_pad)
 
     if expert_map is not None:
         expert_ids = expert_map[expert_ids]
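
For reference, the behavior that both the CUDA kernels and the Triton path above are expected to produce can be sketched in a few lines of eager PyTorch. This is an illustrative sketch only, not vLLM API: the helper name moe_align_block_size_ref is invented here, it returns exact-size outputs instead of filling the caller's pre-allocated max-size buffers, and it loops over experts on the host rather than running on the GPU.

# Illustrative reference for moe_align_block_size (sketch, not the production
# path): group flattened token indices by expert, pad each expert's bucket up
# to a multiple of block_size with the sentinel value topk_ids.numel(), and
# record which expert owns each block of sorted ids.
import torch


def moe_align_block_size_ref(topk_ids: torch.Tensor, num_experts: int,
                             block_size: int):
    flat = topk_ids.flatten().to(torch.int64).cpu()
    sentinel = topk_ids.numel()
    counts = torch.bincount(flat, minlength=num_experts)
    padded = (counts + block_size - 1) // block_size * block_size
    cumsum = torch.zeros(num_experts + 1, dtype=torch.int64)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    total = int(cumsum[-1])

    sorted_ids = torch.full((total,), sentinel, dtype=torch.int32)
    expert_ids = torch.empty((total // block_size,), dtype=torch.int32)
    for expert in range(num_experts):
        start, end = int(cumsum[expert]), int(cumsum[expert + 1])
        expert_ids[start // block_size:end // block_size] = expert
        token_idx = (flat == expert).nonzero(as_tuple=True)[0]
        sorted_ids[start:start + token_idx.numel()] = token_idx.to(torch.int32)
    return sorted_ids, expert_ids, total


# Example from the kernel comments: flattened topk_ids [0,1,2,1,2,3,0,3,4]
# with block_size=4 sorts to [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *,
# 8, *, *, *], where * is the sentinel (9 here).
if __name__ == "__main__":
    topk_ids = torch.tensor([[0, 1, 2], [1, 2, 3], [0, 3, 4]])
    print(moe_align_block_size_ref(topk_ids, num_experts=5, block_size=4))

The max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) allocation used by the benchmark and test above is the worst case of this padding: each expert's bucket can carry at most block_size - 1 padding slots.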