[Core][Distributed] code deduplication in tp&pp with coordinator(#5293)
[Core][Distributed] add coordinator to reduce code duplication in tp and pp (#5293)
This commit is contained in:
@ -7,9 +7,9 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
graph_capture, tensor_model_parallel_all_reduce)
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
||||
get_tp_ca_communicator)
|
||||
get_tp_group, graph_capture)
|
||||
|
||||
from ..utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
@ -91,7 +91,7 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
|
||||
# communicate independently
|
||||
num_communication = rank // tp_size + 1
|
||||
sz = 1024
|
||||
fa = get_tp_ca_communicator()
|
||||
fa = get_tp_group().ca_comm
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = inp
|
||||
for _ in range(num_communication):
|
||||
|
||||
@ -6,10 +6,11 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
graph_capture, tensor_model_parallel_all_reduce)
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
get_world_group, graph_capture,
|
||||
init_distributed_environment)
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
@ -53,7 +54,8 @@ def worker_fn_wrapper(fn):
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator()
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
tensor = torch.ones(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm():
|
||||
def worker_fn_with_cudagraph():
|
||||
with torch.no_grad():
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
pynccl_comm = PyNcclCommunicator()
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
# run something in the default stream to initialize torch engine
|
||||
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
|
||||
torch.cuda.synchronize()
|
||||
@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph():
|
||||
|
||||
@worker_fn_wrapper
|
||||
def send_recv_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator()
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
if pynccl_comm.rank == 0:
|
||||
tensor = torch.ones(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
|
||||
Reference in New Issue
Block a user