[distributed] remove pynccl's redundant stream (#11744)
This commit is contained in:
@@ -137,9 +137,8 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(
-            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
-                enable=True):
+    with torch.cuda.graph(graph), \
+            pynccl_comm.change_state(enable=True):
         a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
     graph.replay()
Reference in New Issue
Block a user