[Misc] Add pynccl wrappers for all_gather and reduce_scatter (#9432)
parent ebda51968b
commit 978b39744b
@@ -150,6 +150,75 @@ def worker_fn_with_cudagraph():
        assert a.mean().cpu().item() == pynccl_comm.world_size**1


@worker_fn_wrapper
def all_gather_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    result = torch.zeros(num_elems * world_size,
                         dtype=torch.float32,
                         device=device)

    expected = torch.cat([
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]).to(device)

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_gather(result, tensor)
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
    distributed_run(all_gather_worker_fn, 2)
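
The wrapper method under test is not part of this hunk. As a rough, hypothetical sketch only: an all_gather method on PyNcclCommunicator would presumably mirror the existing all_reduce wrapper and forward to NCCL's ncclAllGather, with the output sized to hold world_size copies of the input. The binding helpers used below (self.nccl, buffer_type, ncclDataTypeEnum, cudaStream_t) are assumptions based on that pattern, not code taken from this diff.

    # Hypothetical sketch of a PyNcclCommunicator method; the binding helper
    # names here are assumptions, not part of this commit's diff.
    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        if self.disabled:
            return
        # The NCCL communicator is bound to one device; tensors must live there.
        assert input_tensor.device == self.device, (
            f"communicator created for {self.device}, "
            f"got a tensor on {input_tensor.device}")
        if stream is None:
            stream = self.stream
        # ncclAllGather concatenates every rank's sendcount elements into the
        # output, so output_tensor must hold world_size * input_tensor.numel().
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

The test above relies on exactly that contract: result holds num_elems * world_size elements, filled in rank order.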

@worker_fn_wrapper
def reduce_scatter_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    assert (num_elems % world_size == 0)
    result = torch.zeros(num_elems // world_size,
                         dtype=torch.float32,
                         device=device)

    # Calculate expected result for this rank's chunk
    scattered_size = num_elems // world_size
    all_tensors = [
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]
    expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                   for tensor in all_tensors).to(device)

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.reduce_scatter(result, tensor)
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
    distributed_run(reduce_scatter_worker_fn, 2)
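
Correspondingly, a hypothetical sketch of the reduce_scatter wrapper exercised above: it would presumably call NCCL's ncclReduceScatter, reducing element-wise across ranks (sum by default) and leaving this rank's chunk in the output. ReduceOp (assumed to come from torch.distributed) and ncclRedOpTypeEnum are likewise assumptions following the same binding pattern, not code from this diff.

    # Hypothetical sketch, following the same assumed binding pattern.
    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        if self.disabled:
            return
        assert input_tensor.device == self.device, (
            f"communicator created for {self.device}, "
            f"got a tensor on {input_tensor.device}")
        if stream is None:
            stream = self.stream
        # ncclReduceScatter reduces element-wise across ranks and writes this
        # rank's chunk into output, so output_tensor.numel() is the per-rank
        # element count (input_tensor.numel() // world_size).
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

This matches the expected value computed in the test: the per-rank chunk of the element-wise sum of all ranks' tensors.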

@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
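
For reference, assuming this hunk belongs to vLLM's tests/distributed/test_pynccl.py (the file path is not shown on this page), the new tests should run on a host with at least two GPUs via something like:

    pytest tests/distributed/test_pynccl.py -k "all_gather or reduce_scatter"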