[Misc] Make timeout passable in init_distributed_environment (#24522)

Signed-off-by: jberkhahn <jaberkha@us.ibm.com>
This commit is contained in:
Jonathan Berkhahn
2025-09-10 15:41:12 -07:00
committed by GitHub
parent dcb28a332b
commit cc99baf14d

View File

@ -29,6 +29,7 @@ import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
@ -978,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
distributed_init_method: str = "env://",
local_rank: int = -1,
backend: str = "nccl",
):
def init_distributed_environment(world_size: int = -1,
rank: int = -1,
distributed_init_method: str = "env://",
local_rank: int = -1,
backend: str = "nccl",
timeout: Optional[timedelta] = None):
logger.debug(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
@ -1020,7 +1020,8 @@ def init_distributed_environment(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank)
rank=rank,
timeout=timeout)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816