[Core][Distributed] enable multiple tp group (#4512)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
youkaichao
2024-05-01 21:28:21 -07:00
committed by GitHub
parent cf8cac8c70
commit 2a85f93007
4 changed files with 43 additions and 4 deletions

View File

@ -58,6 +58,34 @@ def test_pynccl():
distributed_run(worker_fn, 2)
@worker_fn_wrapper
def multiple_tp_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo")
]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
comm = NCCLCommunicator(group=group, device=device)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
comm.all_reduce(tensor)
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 4
else:
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 2
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_multiple_tp():
distributed_run(worker_fn, 4)
@worker_fn_wrapper
def worker_fn_with_cudagraph():
with torch.no_grad():