custom allreduce + torch.compile (#10121)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
@@ -60,7 +60,7 @@ def worker_fn():
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
     result = tensor.mean().cpu().item()
     assert result == pynccl_comm.world_size
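The hunk above is the heart of the commit: `all_reduce` switches from mutating its argument in place to returning the reduced tensor, which is what a traced graph (CUDA graph or torch.compile) needs in order to track the output. A minimal sketch of the same out-of-place pattern using stock torch.distributed, assuming an already-initialized process group (`all_reduce_out_of_place` is a hypothetical helper, not vLLM API):

import torch
import torch.distributed as dist

def all_reduce_out_of_place(t: torch.Tensor) -> torch.Tensor:
    # Reduce into a fresh buffer and return it; the caller's tensor
    # is left untouched, mirroring the new call style in the test:
    #     tensor = pynccl_comm.all_reduce(tensor)
    out = t.clone()
    dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out

# e.g., with every rank holding an all-ones input:
#     tensor = all_reduce_out_of_place(tensor)
#     assert tensor.mean().item() == dist.get_world_size()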
@@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn():
     with pynccl_comm.change_state(enable=True):
         # two groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.all_reduce(tensor)
-            pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
             result = tensor.mean().cpu().item()
             assert result == 4
         else:
-            pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
             result = tensor.mean().cpu().item()
             assert result == 2
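For context, the subgroup independence this test asserts (two reductions among ranks 0 and 1 give 4, one reduction among the remaining ranks gives 2) can be sketched with plain torch.distributed subgroups; the 4-rank layout below is an assumption inferred from the asserts:

import torch
import torch.distributed as dist

def two_group_allreduce_sketch():
    # Assumes a 4-process job; every rank must call new_group for
    # both subgroups, even the ones it does not belong to.
    rank = dist.get_rank()
    low = dist.new_group(ranks=[0, 1])
    high = dist.new_group(ranks=[2, 3])
    tensor = torch.ones(16, device=f"cuda:{rank}")
    if rank in (0, 1):
        # two back-to-back reductions inside the 2-rank group: 1 -> 2 -> 4
        for _ in range(2):
            dist.all_reduce(tensor, group=low)
        assert tensor.mean().item() == 4
    else:
        dist.all_reduce(tensor, group=high)  # 1 -> 2
        assert tensor.mean().item() == 2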
@@ -140,14 +140,11 @@ def worker_fn_with_cudagraph():
     with torch.cuda.graph(
             graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
                 enable=True):
         # operation during the graph capture is recorded but not executed
         # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
-        pynccl_comm.all_reduce(a)
+        a_out = pynccl_comm.all_reduce(a)
     pynccl_comm.stream.synchronize()
-    assert a.mean().cpu().item() == pynccl_comm.world_size**0
     graph.replay()
     pynccl_comm.stream.synchronize()
-    assert a.mean().cpu().item() == pynccl_comm.world_size**1
+    assert a_out.mean().cpu().item() == pynccl_comm.world_size**1


 @worker_fn_wrapper
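The deleted `world_size**0` assert reflects the CUDA-graph rule the comments cite: work issued during capture is recorded, not executed, so the output only becomes valid after `graph.replay()`. A single-GPU toy capture showing the same record-then-replay behavior (no NCCL involved; buffer names are illustrative):

import torch

def capture_and_replay_sketch():
    a = torch.ones(4, 4, device="cuda")
    out = torch.zeros_like(a)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out.copy_(a * 2)              # recorded only, nothing runs yet
    torch.cuda.synchronize()
    assert out.mean().item() == 0     # still the pre-capture zeros
    g.replay()                        # now the captured work executes
    torch.cuda.synchronize()
    assert out.mean().item() == 2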
@@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     pynccl1 = PyNcclCommunicator(pg1, device=rank)
-    pynccl1.disabled = False
     if rank <= 2:
         pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                            port=port2,
                                            rank=rank,
                                            world_size=3)
         pynccl2 = PyNcclCommunicator(pg2, device=rank)
-        pynccl2.disabled = False
     data = torch.tensor([rank]).cuda()
     pynccl1.all_reduce(data)
     pg1.barrier()
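This hunk drops the manual `disabled = False` toggles, suggesting communicators built from a `StatelessProcessGroup` are now usable as-is. A hedged sketch of how such a worker might be driven end to end, assuming the vLLM helpers shown in the diff keep their usual import paths and that `all_reduce` returns the result per this commit (the port number and spawn logic are illustrative):

import multiprocessing

import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

WORLD_SIZE = 4

def gpu_worker(rank: int, port: int):
    torch.cuda.set_device(rank)
    # Stateless group: rendezvous over TCP, no torch.distributed init.
    pg = StatelessProcessGroup.create(host="127.0.0.1",
                                      port=port,
                                      rank=rank,
                                      world_size=WORLD_SIZE)
    comm = PyNcclCommunicator(pg, device=rank)
    data = torch.tensor([float(rank)]).cuda()
    data = comm.all_reduce(data)      # out-of-place after this commit
    assert data.item() == sum(range(WORLD_SIZE))  # 0+1+2+3 = 6
    pg.barrier()

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")  # required for CUDA workers
    procs = [
        multiprocessing.Process(target=gpu_worker, args=(r, 29501))
        for r in range(WORLD_SIZE)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()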