[Misc] Make timeout passable in init_distributed_environment (#24522)
Signed-off-by: jberkhahn <jaberkha@us.ibm.com>
This commit is contained in:
committed by
GitHub
parent
dcb28a332b
commit
cc99baf14d
@ -29,6 +29,7 @@ import weakref
|
||||
from collections import namedtuple
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from unittest.mock import patch
|
||||
@ -978,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
distributed_init_method: str = "env://",
|
||||
local_rank: int = -1,
|
||||
backend: str = "nccl",
|
||||
):
|
||||
def init_distributed_environment(world_size: int = -1,
|
||||
rank: int = -1,
|
||||
distributed_init_method: str = "env://",
|
||||
local_rank: int = -1,
|
||||
backend: str = "nccl",
|
||||
timeout: Optional[timedelta] = None):
|
||||
logger.debug(
|
||||
"world_size=%d rank=%d local_rank=%d "
|
||||
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
|
||||
@ -1020,7 +1020,8 @@ def init_distributed_environment(
|
||||
backend=backend,
|
||||
init_method=distributed_init_method,
|
||||
world_size=world_size,
|
||||
rank=rank)
|
||||
rank=rank,
|
||||
timeout=timeout)
|
||||
# set the local rank
|
||||
# local_rank is not available in torch ProcessGroup,
|
||||
# see https://github.com/pytorch/pytorch/issues/122816
|
||||
|
||||
Reference in New Issue
Block a user