[distributed] remove pynccl's redundant stream (#11744)

This commit is contained in:
cennn
2025-01-05 23:09:11 +08:00
committed by GitHub
parent 4068f4b5b5
commit 635b897246
3 changed files with 12 additions and 24 deletions

View File

@@ -137,9 +137,8 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(
-            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
-            enable=True):
+    with torch.cuda.graph(graph), \
+            pynccl_comm.change_state(enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()