From fcb1d570bb8f95f5b7ded716a52fec902c535f0e Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Wed, 29 Oct 2025 14:50:39 -0400
Subject: [PATCH] [Bug] Fix DeepEP low latency `assert
 self.batched_router_logits.size(-1) == full_router_logits.size(-1)` Bug
 (#27682)

Signed-off-by: yewentao256
---
 vllm/model_executor/layers/fused_moe/layer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 294dddade6..7dbe4bc543 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1135,6 +1135,7 @@ class FusedMoE(CustomOp):
         )
 
         self.global_num_experts = num_experts + num_redundant_experts
+        self.logical_num_experts = num_experts
 
         self.zero_expert_num = zero_expert_num
         self.zero_expert_type = zero_expert_type
@@ -1998,13 +1999,12 @@ class FusedMoE(CustomOp):
 
         moe = self.moe_config
 
-        # Note here we use `num_experts` which is logical expert count
         if self.vllm_config.parallel_config.enable_dbo:
             states_shape = (2, moe.max_num_tokens, self.hidden_size)
-            logits_shape = (2, moe.max_num_tokens, moe.num_experts)
+            logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
         else:
             states_shape = (moe.max_num_tokens, self.hidden_size)
-            logits_shape = (moe.max_num_tokens, moe.num_experts)
+            logits_shape = (moe.max_num_tokens, self.logical_num_experts)
 
         self.batched_hidden_states = torch.zeros(
             states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
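
Reviewer note: a minimal, self-contained sketch of why the assert fired and why
sizing the buffer with the logical expert count fixes it. This is not vLLM's
actual code; the expert counts, token count, and hidden size below are made-up
illustration values.

    import torch

    num_experts = 64            # logical experts; matches the gate's output width
    num_redundant_experts = 2   # hypothetical extra replicas for expert load balancing
    global_num_experts = num_experts + num_redundant_experts  # 66

    max_num_tokens = 8
    hidden_size = 16

    # The gate projects hidden states to one logit per *logical* expert,
    # so router logits always have num_experts columns, never the global count.
    gate = torch.nn.Linear(hidden_size, num_experts, bias=False)
    full_router_logits = gate(torch.randn(max_num_tokens, hidden_size))

    # Before the fix, the preallocated staging buffer used the global count,
    # so its last dim (66) disagreed with the gate output (64) and the
    # `batched_router_logits.size(-1) == full_router_logits.size(-1)` assert fired.
    buggy_buffer = torch.zeros(max_num_tokens, global_num_experts)
    assert buggy_buffer.size(-1) != full_router_logits.size(-1)

    # After the fix, the buffer is sized with the logical count and matches.
    fixed_buffer = torch.zeros(max_num_tokens, num_experts)
    assert fixed_buffer.size(-1) == full_router_logits.size(-1)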