From cfbca8a2f2b3dafe210d0bffa944f00a40112ac0 Mon Sep 17 00:00:00 2001
From: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com>
Date: Wed, 19 Mar 2025 20:55:18 -0400
Subject: [PATCH] [V1] TPU - Tensor parallel MP support (#15059)

---
 vllm/config.py                                    |  2 +-
 .../device_communicators/tpu_communicator.py      | 48 ++++++++++++++-----
 2 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index c248122da0..2d8f1ba483 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1473,7 +1473,7 @@ class ParallelConfig:
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
             logger.info("Disabling V1 multiprocessing for external launcher.")
 
-        ray_only_devices = ["tpu"]
+        ray_only_devices: list[str] = []
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
                 and self.world_size > 1):
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index 524e655b6b..05cb1e0f6e 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -6,16 +6,25 @@ from typing import Optional
 import torch
 from torch.distributed import ProcessGroup
 
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
 from .base_device_communicator import DeviceCommunicatorBase
 
+USE_RAY = (get_current_vllm_config().parallel_config.distributed_executor_backend
+           == "ray")
+
+logger = init_logger(__name__)
+
 if current_platform.is_tpu():
+    import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
 
-    from vllm.executor import ray_utils
+    if USE_RAY:
+        from vllm.executor import ray_utils
 
 
 class TpuCommunicator(DeviceCommunicatorBase):
@@ -33,19 +42,32 @@ class TpuCommunicator(DeviceCommunicatorBase):
         global_rank = self.global_rank
         global_world_size = self.global_world_size
 
-        # Calculate how many TPU nodes are in the current deployment. This
-        # is the Ray placement group if it is deployed with Ray. Default
-        # to the number of TPU nodes in the Ray cluster. The number of TPU
-        # nodes is computed by the total number of TPUs divided by the
-        # number of TPU accelerators per node, to account for clusters
-        # with both CPUs and TPUs.
-        num_nodes = ray_utils.get_num_tpu_nodes()
-        num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
-        if num_nodes_in_pg > 0:
-            num_nodes = num_nodes_in_pg
+        if USE_RAY:
+            logger.info("TpuCommunicator initialized with RAY")
+            # Calculate how many TPU nodes are in the current deployment. This
+            # is the Ray placement group if it is deployed with Ray. Default
+            # to the number of TPU nodes in the Ray cluster. The number of TPU
+            # nodes is computed by the total number of TPUs divided by the
+            # number of TPU accelerators per node, to account for clusters
+            # with both CPUs and TPUs.
+            num_nodes = ray_utils.get_num_tpu_nodes()
+            num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+            if num_nodes_in_pg > 0:
+                num_nodes = num_nodes_in_pg
 
-        local_world_size = global_world_size // num_nodes
-        local_rank = global_rank % local_world_size
+            local_world_size = global_world_size // num_nodes
+            local_rank = global_rank % local_world_size
+        else:
+            logger.info("TpuCommunicator initialized with MP")
+            # Sanity check: verify that we are running on a single TPU host.
+            num_hosts = torch_xla.tpu.num_tpu_workers()
+            assert num_hosts == 1
+
+            # Get the number of TPU chips available on this host.
+            local_world_size = torch_xla.tpu.num_available_chips()
+
+            # Compute the local rank on this host.
+            local_rank = global_rank % local_world_size
 
         # Ensure environment variables are set for multihost deployments.
         # On GKE, this is needed for libtpu and TPU driver to know which TPU
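
Note: the sketch below is not part of the patch. It is a minimal, standalone illustration of the local-rank bookkeeping the new code performs; the function name and the num_nodes / num_local_chips parameters are invented for this example (the real code queries ray_utils or torch_xla instead of taking them as arguments).

def compute_local_topology(global_rank: int, global_world_size: int,
                           use_ray: bool, num_nodes: int,
                           num_local_chips: int) -> tuple[int, int]:
    """Return (local_world_size, local_rank) for one TPU worker."""
    if use_ray:
        # Ray path: the global world is spread evenly across TPU nodes,
        # so each node owns global_world_size // num_nodes workers.
        local_world_size = global_world_size // num_nodes
    else:
        # MP path: all workers run on one host, so the local world size
        # is simply the number of TPU chips on that host.
        local_world_size = num_local_chips
    local_rank = global_rank % local_world_size
    return local_world_size, local_rank

# Example: MP backend, 4 workers on one host with 4 TPU chips.
assert compute_local_topology(3, 4, False, 1, 4) == (4, 3)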