[Bugfix][Rocm] fix qr error when different inp shape (#25892)

Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
haoyangli-amd
2025-10-14 01:04:21 +08:00
committed by GitHub
parent a1b2d658ee
commit 134f70b3ed
3 changed files with 96 additions and 11 deletions

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import random
import pytest
@ -8,6 +9,7 @@ import ray
import torch
import torch.distributed as dist
from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.parallel_state import get_tp_group, graph_capture
from vllm.platforms import current_platform
@ -134,3 +136,88 @@ def test_custom_quick_allreduce(
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
def qr_variable_input(rank, world_size):
"""
When the tensor parallelism is set to 4 or 8, frequent changes
in the input shape can cause QuickReduce to hang (this issue
has been observed with the gpt_oss model).
"""
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
qr_max_size = None # MB
_ptr = ops.init_custom_qr(rank, world_size, qr_max_size)
ranks = []
for i in range(world_size):
ranks.append(i)
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
)
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
handle = ops.qr_get_handle(_ptr)
world_size = dist.get_world_size(group=cpu_group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=cpu_group)
ops.qr_open_handles(_ptr, handles)
num = 1
s1 = 1024
while num < 50000: # 50000 is sufficient to identify issues.
dtype = torch.float16
if num % 2 == 0:
s2 = 1024
inp1 = torch.zeros(
(s1, s2), dtype=dtype, device=torch.cuda.current_device()
)
else:
s2 = 2048
inp1 = torch.ones((s1, s2), dtype=dtype, device=torch.cuda.current_device())
result = torch.empty_like(inp1)
# FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4
ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True)
try:
if inp1[0, 0] == 0:
assert torch.all(result == 0)
else:
assert torch.all(result == world_size)
except AssertionError:
print("Assertion failed! Allreduce results are incorrect.")
raise
num += 1
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("tp_size", [4, 8])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
multiprocessing.set_start_method("spawn", force=True)
# 60s is enough
timeout = 60
processes = []
for rank in range(tp_size):
p = multiprocessing.Process(target=qr_variable_input, args=(rank, tp_size))
p.start()
processes.append((rank, p))
for rank, p in processes:
p.join(timeout=timeout)
if p.is_alive():
for r, proc in processes:
if proc.is_alive():
proc.terminate()
proc.join()
raise RuntimeError(f"QuickReduce hang detected after {timeout} seconds!")
if __name__ == "__main__":
test_custom_quick_allreduce_variable_input(tp_size=4, pipeline_parallel_size=1)