[CI] Add tests for cudagraph (#27391)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu
2025-10-25 10:37:33 +08:00
committed by GitHub
parent 83f478bb19
commit 29c9cb8007
4 changed files with 54 additions and 18 deletions

View File

@ -1111,6 +1111,11 @@ def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]
# `cloudpickle` allows pickling complex functions directly
input_bytes = cloudpickle.dumps((f, output_filepath))
repo_root = str(VLLM_PATH.resolve())
env = dict(env or os.environ)
env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
cmd = [sys.executable, "-m", f"{module_name}"]
returned = subprocess.run(

View File

@ -34,13 +34,16 @@ class SimpleMLP(nn.Module):
def _create_vllm_config(
compilation_config: CompilationConfig, max_num_seqs: int = 8
compilation_config: CompilationConfig,
max_num_seqs: int = 8,
lora_config: bool = False,
) -> MagicMock:
mock_config = MagicMock(spec=VllmConfig)
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
mock_config.parallel_config = ParallelConfig()
if not lora_config:
mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1()
@ -50,19 +53,21 @@ def _create_vllm_config(
class TestCudagraphDispatcher:
@pytest.mark.parametrize(
"case_id,cudagraph_mode_str,compilation_mode",
"cudagraph_mode_str,compilation_mode,lora_config",
[
# Test case 0: Full CG for mixed batches, no separate routine
(0, "FULL", CompilationMode.NONE),
("FULL", CompilationMode.NONE, False),
# Test case 1: Full CG for uniform batches, piecewise for mixed
(1, "FULL_AND_PIECEWISE", CompilationMode.NONE),
("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
# Test case 2: Full CG for uniform batches, no CG for mixed
(2, "FULL_DECODE_ONLY", CompilationMode.NONE),
("FULL_DECODE_ONLY", CompilationMode.NONE, False),
# Test case 3: PIECEWISE for all
(3, "PIECEWISE", CompilationMode.VLLM_COMPILE),
("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
# Test case 4: PIECEWISE for all, specialize LoRA cases
("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
],
)
def test_dispatcher(self, cudagraph_mode_str, compilation_mode):
def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
# Setup dispatcher
comp_config = CompilationConfig(
cudagraph_mode=cudagraph_mode_str,
@ -70,7 +75,17 @@ class TestCudagraphDispatcher:
cudagraph_capture_sizes=[1, 8],
)
config = _create_vllm_config(comp_config, max_num_seqs=8)
config = _create_vllm_config(
comp_config, max_num_seqs=8, lora_config=lora_config
)
if (
cudagraph_mode_str == "FULL_AND_PIECEWISE"
and compilation_mode == CompilationMode.NONE
):
with pytest.raises(AssertionError):
dispatcher = CudagraphDispatcher(config)
return
dispatcher = CudagraphDispatcher(config)
dispatcher.initialize_cudagraph_keys(
cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
@ -78,17 +93,24 @@ class TestCudagraphDispatcher:
# Verify the key is initialized correctly
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
4 if lora_config else 2
)
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
4 if lora_config else 2
)
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
# Test dispatch logic
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
desc_full_exact = BatchDescriptor(
num_tokens=8,
uniform_decode=False,
)
rt_mode, key = dispatcher.dispatch(desc_full_exact)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
@ -138,7 +160,6 @@ class TestCUDAGraphWrapper:
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
self.input_tensor = torch.randn(1, 10, device="cuda")
@create_new_process_for_each_test("spawn")
def test_capture_and_replay(self):
wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
@ -192,7 +213,6 @@ class TestCUDAGraphWrapper:
eager_output = self.model(self.input_tensor)
torch.testing.assert_close(eager_output, output2)
@create_new_process_for_each_test("spawn")
def test_bypass_on_mode_mismatch(self):
wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
@ -216,7 +236,6 @@ class TestCUDAGraphWrapper:
mock_forward.assert_called_once()
assert not wrapper.concrete_cudagraph_entries
@create_new_process_for_each_test("spawn")
def test_bypass_on_mode_none(self):
wrapper = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL

View File

@ -109,9 +109,9 @@ combo_cases_2 = [
@pytest.mark.parametrize(
"backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2
)
def test_cudagraph_compilation_combo(combo_case):
backend_name, cudagraph_mode, compilation_mode, supported = combo_case
def test_cudagraph_compilation_combo(
backend_name, cudagraph_mode, compilation_mode, supported
):
env_vars = backend_configs[backend_name].env_vars
with temporary_environ(env_vars), ExitStack() as stack: