[Distributed] Add send and recv helpers (#5719)
This commit is contained in:
committed by
GitHub
parent
6c916ac8a8
commit
5d4d90536f
@ -8,12 +8,11 @@ import pytest
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (broadcast_tensor_dict,
|
||||
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
|
||||
from ..utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
from ..utils import init_test_distributed_environment, multi_process_parallel
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
assert torch.allclose(recv_dict["f"], test_dict["f"])
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
distributed_init_port)
|
||||
|
||||
test_dict = {
|
||||
# device tensor
|
||||
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
||||
# CPU tensor
|
||||
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
||||
"c": "test",
|
||||
"d": [1, 2, 3],
|
||||
"e": {
|
||||
"a": 1,
|
||||
"b": 2
|
||||
},
|
||||
# empty tensor
|
||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||
}
|
||||
|
||||
if not get_pp_group().is_first_rank:
|
||||
recv_dict = get_pp_group().recv_tensor_dict()
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
get_pp_group().send_tensor_dict(test_dict)
|
||||
|
||||
if not get_pp_group().is_first_rank:
|
||||
assert len(recv_dict) == len(test_dict)
|
||||
assert torch.allclose(recv_dict["a"], test_dict["a"])
|
||||
assert torch.allclose(recv_dict["b"], test_dict["b"])
|
||||
assert recv_dict["c"] == test_dict["c"]
|
||||
assert recv_dict["d"] == test_dict["d"]
|
||||
assert recv_dict["e"] == test_dict["e"]
|
||||
assert torch.allclose(recv_dict["f"], test_dict["f"])
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
distributed_init_port)
|
||||
|
||||
size = 64
|
||||
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
|
||||
|
||||
if not get_pp_group().is_first_rank:
|
||||
recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
get_pp_group().send(test_tensor)
|
||||
|
||||
if not get_pp_group().is_first_rank:
|
||||
assert torch.allclose(test_tensor, recv_tensor)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
broadcast_tensor_dict_test_worker
|
||||
])
|
||||
def test_multi_process_tensor_parallel(tp_size, test_target):
|
||||
multi_process_tensor_parallel(tp_size, 1, test_target)
|
||||
multi_process_parallel(tp_size, 1, test_target)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("pp_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
|
||||
def test_multi_process_pipeline_parallel(pp_size, test_target):
|
||||
multi_process_parallel(1, pp_size, test_target)
|
||||
|
||||
@ -12,8 +12,7 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
||||
get_tp_group, graph_capture)
|
||||
|
||||
from ..utils import (ensure_model_parallel_initialized,
|
||||
init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
init_test_distributed_environment, multi_process_parallel)
|
||||
|
||||
random.seed(42)
|
||||
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
|
||||
@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
|
||||
world_size = tp_size * pipeline_parallel_size
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
|
||||
multi_process_parallel(tp_size, pipeline_parallel_size, test_target)
|
||||
|
||||
@ -168,9 +168,13 @@ def send_recv_worker_fn():
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
if pynccl_comm.rank == 0:
|
||||
pynccl_comm.send(tensor)
|
||||
pynccl_comm.send(tensor,
|
||||
dst=(pynccl_comm.rank + 1) %
|
||||
pynccl_comm.world_size)
|
||||
else:
|
||||
pynccl_comm.recv(tensor)
|
||||
pynccl_comm.recv(tensor,
|
||||
src=(pynccl_comm.rank - 1) %
|
||||
pynccl_comm.world_size)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == 1
|
||||
|
||||
@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn():
|
||||
device=device)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
if torch.distributed.get_rank() in [0, 1]:
|
||||
pynccl_comm.send(tensor)
|
||||
pynccl_comm.send(tensor,
|
||||
dst=(pynccl_comm.rank + 1) %
|
||||
pynccl_comm.world_size)
|
||||
else:
|
||||
pynccl_comm.recv(tensor)
|
||||
pynccl_comm.recv(tensor,
|
||||
src=(pynccl_comm.rank - 1) %
|
||||
pynccl_comm.world_size)
|
||||
result = tensor.mean().cpu().item()
|
||||
if torch.distributed.get_rank() in [0, 2]:
|
||||
assert result == 1
|
||||
|
||||
@ -129,7 +129,7 @@ def init_test_distributed_environment(
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
|
||||
|
||||
def multi_process_tensor_parallel(
|
||||
def multi_process_parallel(
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
test_target,
|
||||
|
||||
Reference in New Issue
Block a user