diff --git a/vllm/envs.py b/vllm/envs.py
index d7ba43c825..a62eeac2b0 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -338,6 +338,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "CUDA_VISIBLE_DEVICES":
     lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
 
+    # used to control the visible devices in the distributed setting
+    "VLLM_VISIBLE_DEVICES":
+    lambda: os.environ.get("VLLM_VISIBLE_DEVICES", None),
+
     # timeout for each iteration in the engine
     "VLLM_ENGINE_ITERATION_TIMEOUT_S":
     lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 6b30acee1d..0441e6c8a8 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -135,7 +135,14 @@ class Worker(WorkerBase):
 
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"cuda:{self.local_rank}")
+
+            device_id = self.local_rank
+            if envs.VLLM_VISIBLE_DEVICES is not None:
+                devices = [
+                    int(dev.strip()) for dev in envs.VLLM_VISIBLE_DEVICES.split(',')
+                ]
+                device_id = devices[self.local_rank]
+            self.device = torch.device(f"cuda:{device_id}")
             current_platform.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
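
For illustration, a minimal standalone sketch of the device-selection logic the patch adds. resolve_device_id is a hypothetical helper written only for this example (it is not part of the patch), and the VLLM_VISIBLE_DEVICES value below is just an example setting:

import os

def resolve_device_id(local_rank: int) -> int:
    # Mirrors the logic added to gpu_worker.py above: if VLLM_VISIBLE_DEVICES
    # is set, map the worker's local rank onto the listed physical device id;
    # otherwise fall back to the local rank itself.
    visible = os.environ.get("VLLM_VISIBLE_DEVICES")
    if visible is None:
        return local_rank
    devices = [int(dev.strip()) for dev in visible.split(",")]
    return devices[local_rank]

# Example: two workers pinned to GPUs 2 and 3.
os.environ["VLLM_VISIBLE_DEVICES"] = "2,3"
assert resolve_device_id(0) == 2  # worker with local_rank 0 -> cuda:2
assert resolve_device_id(1) == 3  # worker with local_rank 1 -> cuda:3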