Enable Allgather/ReduceScatter backend for NaiveAllToAll (#23964)

Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: Shu Wang
Date: 2025-09-18 10:52:58 -05:00
Committed by: GitHub
Commit: 2ea50e977a (parent: b419937c78)

3 changed files with 55 additions and 5 deletions

@@ -5,6 +5,7 @@ from typing import Any
 import torch
 import torch.distributed as dist

 from vllm.distributed import get_dp_group
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
@@ -69,6 +70,44 @@ class NaiveAll2AllManager(All2AllManagerBase):
         pass


+class AgRsAll2AllManager(All2AllManagerBase):
+    """
+    An implementation of all2all communication based on
+    all-gather (dispatch) and reduce-scatter (combine).
+    """
+
+    def __init__(self, cpu_group):
+        super().__init__(cpu_group)
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        """
+        Gather hidden_states and router_logits from all dp ranks.
+        """
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states, router_logits = get_dp_group().all_gatherv(
+            [hidden_states, router_logits],
+            dim=0,
+            sizes=sizes,
+        )
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Reduce-scatter hidden_states across all dp ranks.
+        """
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
+                                                       dim=0,
+                                                       sizes=sizes)
+        return hidden_states
+
+    def destroy(self):
+        pass
+
+
 class PPLXAll2AllManager(All2AllManagerBase):
     """
     All2All communication based on PPLX kernels.
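For readers unfamiliar with the collective pattern: the new manager replaces an explicit all-to-all with an all-gather on dispatch and a reduce-scatter on combine. Below is a minimal single-process sketch of those semantics only; it is not vLLM code, and the DP size, token counts, tensor shapes, and the 0.5 "expert output" are made up for illustration.

# Single-process sketch of the all-gather (dispatch) / reduce-scatter (combine)
# semantics used by AgRsAll2AllManager. Numbers below are illustrative only.
import torch

dp_size = 2
hidden = 8
# Per-rank token counts, analogous to get_chunk_sizes_across_dp_rank().
sizes = [3, 5]

# Each DP rank holds only its own tokens before dispatch.
per_rank_hidden = [torch.randn(n, hidden) for n in sizes]

# Dispatch ~ all_gatherv along dim 0: every rank ends up with all tokens.
gathered = torch.cat(per_rank_hidden, dim=0)          # shape [8, hidden]

# ... MoE experts would run on `gathered`; each rank produces a partial output.
partial_outputs = [gathered * 0.5 for _ in range(dp_size)]

# Combine ~ reduce_scatterv: sum the partials, then each rank keeps its chunk.
reduced = torch.stack(partial_outputs).sum(dim=0)
chunks = torch.split(reduced, sizes, dim=0)
local_out = chunks[0]                                 # rank 0 gets its 3 tokens back
print(local_out.shape)                                # torch.Size([3, 8])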

@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
                from .all2all import NaiveAll2AllManager
                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
                logger.info("Using naive all2all manager.")
+            elif all2all_backend == "allgather_reducescatter":
+                from .all2all import AgRsAll2AllManager
+                self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
+                logger.info("Using AllGather-ReduceScatter all2all manager.")
             elif all2all_backend == "pplx":
                from .all2all import PPLXAll2AllManager
                self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
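A hedged usage sketch: the variable name and value come from this commit, while setting it before vLLM builds its device communicator, and reading it back through vllm.envs, are assumptions about typical usage rather than documented API.

# Assumed usage: select the backend before vLLM creates its communicator.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "allgather_reducescatter"

import vllm.envs as envs
# Expected to print "allgather_reducescatter" unless overridden elsewhere.
print(envs.VLLM_ALL2ALL_BACKEND)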

@@ -149,8 +149,11 @@ if TYPE_CHECKING:
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
-    VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "deepep_high_throughput",
-                                  "deepep_low_latency"] = "naive"
+    VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
+                                  "deepep_high_throughput",
+                                  "deepep_low_latency",
+                                  "allgather_reducescatter"] = \
+        "allgather_reducescatter"
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
@@ -1124,14 +1127,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # all2all backend for vllm's expert parallel communication
     # Available options:
-    # - "naive": naive all2all implementation using all-reduce
+    # - "naive": naive all2all implementation using broadcasts
+    # - "allgather_reducescatter": all2all implementation based on allgather and
+    #   reducescatter
     # - "pplx": use pplx kernels
     # - "deepep_high_throughput", use deepep high-throughput kernels
     # - "deepep_low_latency", use deepep low-latency kernels
     "VLLM_ALL2ALL_BACKEND":
-    env_with_choices("VLLM_ALL2ALL_BACKEND", "naive",
+    env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
                      ["naive", "pplx",
-                      "deepep_high_throughput", "deepep_low_latency"]),
+                      "deepep_high_throughput",
+                      "deepep_low_latency",
+                      "allgather_reducescatter"]),

     # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
     # Both require compute capability 10.0 or above.
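For illustration only, a minimal env_with_choices-style helper is sketched below to show the validation pattern the entry above relies on. This is not vLLM's actual implementation; it is a stand-in that only reuses the name, default, and choices visible in the diff.

# Illustrative stand-in for an env_with_choices-style helper (not vLLM's code):
# return a getter that reads the variable and validates it against the choices.
import os
from typing import Callable

def env_with_choices(name: str, default: str,
                     choices: list[str]) -> Callable[[], str]:
    def getter() -> str:
        value = os.getenv(name, default)
        if value not in choices:
            raise ValueError(f"{name}={value!r} is not one of {choices}")
        return value
    return getter

get_backend = env_with_choices("VLLM_ALL2ALL_BACKEND",
                               "allgather_reducescatter",
                               ["naive", "pplx",
                                "deepep_high_throughput",
                                "deepep_low_latency",
                                "allgather_reducescatter"])
# Prints "allgather_reducescatter" unless the environment variable overrides it.
print(get_backend())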