Enable Allgather/ReduceScatter backend for NaiveAllToAll (#23964)
Signed-off-by: Shu Wang. <shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@ -5,6 +5,7 @@ from typing import Any
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import has_deep_ep, has_pplx
|
||||
@ -69,6 +70,44 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
pass
|
||||
|
||||
|
||||
class AgRsAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
An implementation of all2all communication based on
|
||||
all-gather (dispatch) and reduce-scatter (combine).
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
|
||||
@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
from .all2all import NaiveAll2AllManager
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
elif all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
logger.info("Using AllGather-ReduceScatter all2all manager.")
|
||||
elif all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
|
||||
17
vllm/envs.py
17
vllm/envs.py
@ -149,8 +149,11 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
|
||||
VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "deepep_high_throughput",
|
||||
"deepep_low_latency"] = "naive"
|
||||
VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
|
||||
"deepep_high_throughput",
|
||||
"deepep_low_latency",
|
||||
"allgather_reducescatter"] = \
|
||||
"allgather_reducescatter"
|
||||
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
|
||||
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
|
||||
VLLM_SLEEP_WHEN_IDLE: bool = False
|
||||
@ -1124,14 +1127,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
|
||||
# all2all backend for vllm's expert parallel communication
|
||||
# Available options:
|
||||
# - "naive": naive all2all implementation using all-reduce
|
||||
# - "naive": naive all2all implementation using broadcasts
|
||||
# - "allgather_reducescatter": all2all implementation based on allgather and
|
||||
# reducescatter
|
||||
# - "pplx": use pplx kernels
|
||||
# - "deepep_high_throughput", use deepep high-throughput kernels
|
||||
# - "deepep_low_latency", use deepep low-latency kernels
|
||||
"VLLM_ALL2ALL_BACKEND":
|
||||
env_with_choices("VLLM_ALL2ALL_BACKEND", "naive",
|
||||
env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
|
||||
["naive", "pplx",
|
||||
"deepep_high_throughput", "deepep_low_latency"]),
|
||||
"deepep_high_throughput",
|
||||
"deepep_low_latency",
|
||||
"allgather_reducescatter"]),
|
||||
|
||||
# Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
|
||||
# Both require compute capability 10.0 or above.
|
||||
|
||||
Reference in New Issue
Block a user