[CI] Add tests for cudagraph (#27391)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@@ -1111,6 +1111,11 @@ def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]
         # `cloudpickle` allows pickling complex functions directly
         input_bytes = cloudpickle.dumps((f, output_filepath))
 
+        repo_root = str(VLLM_PATH.resolve())
+
+        env = dict(env or os.environ)
+        env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
+
         cmd = [sys.executable, "-m", f"{module_name}"]
 
         returned = subprocess.run(
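Note: the helper above works by cloudpickling the test function and feeding it to a fresh interpreter over stdin; prepending the repo root to PYTHONPATH is what lets the child process resolve in-repo imports. A minimal standalone sketch of the same pattern (the run_in_subprocess helper here is hypothetical, not the vLLM code itself):

import os
import subprocess
import sys

import cloudpickle


def run_in_subprocess(fn, *args):
    """Sketch: run `fn(*args)` in a fresh Python process.

    The callable is serialized with cloudpickle (which, unlike the stdlib
    pickle, handles closures and locally defined functions) and sent to
    the child over stdin.
    """
    payload = cloudpickle.dumps((fn, args))
    env = dict(os.environ)
    # Assumption: this file lives at the repo root; prepend it so the
    # child can import in-repo modules, mirroring the PYTHONPATH change.
    repo_root = os.path.dirname(os.path.abspath(__file__))
    env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
    child_src = (
        "import sys, cloudpickle; "
        "fn, args = cloudpickle.loads(sys.stdin.buffer.read()); "
        "fn(*args)"
    )
    subprocess.run(
        [sys.executable, "-c", child_src], input=payload, env=env, check=True
    )


if __name__ == "__main__":
    run_in_subprocess(print, "hello from a child process")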
@@ -34,13 +34,16 @@ class SimpleMLP(nn.Module):
 
 
 def _create_vllm_config(
-    compilation_config: CompilationConfig, max_num_seqs: int = 8
+    compilation_config: CompilationConfig,
+    max_num_seqs: int = 8,
+    lora_config: bool = False,
 ) -> MagicMock:
     mock_config = MagicMock(spec=VllmConfig)
     mock_config.compilation_config = compilation_config
     mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
     mock_config.parallel_config = ParallelConfig()
+    if not lora_config:
+        mock_config.lora_config = None
     # Mimic the behavior of VllmConfig.__post_init__()
     if compilation_config.mode == CompilationMode.VLLM_COMPILE:
         compilation_config.set_splitting_ops_for_v1()
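Side note on the mock setup: MagicMock(spec=VllmConfig) constrains the fake config to the real class's attribute surface, so a mistyped field fails with AttributeError instead of silently yielding a child mock. A quick illustration using only stdlib unittest.mock (the Config class below is a made-up stand-in, not VllmConfig):

from unittest.mock import MagicMock


class Config:
    """Stand-in for a real config class; attributes live on the class."""

    max_num_seqs = 8


strict = MagicMock(spec=Config)
strict.max_num_seqs = 16        # allowed: the attribute is part of the spec
try:
    _ = strict.max_num_seqss    # typo: not on the spec
except AttributeError as exc:
    print(f"caught as expected: {exc}")

loose = MagicMock()
print(loose.max_num_seqss)      # no spec: silently returns a child mock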
@@ -50,19 +53,21 @@ def _create_vllm_config(
 
 
 class TestCudagraphDispatcher:
     @pytest.mark.parametrize(
-        "case_id,cudagraph_mode_str,compilation_mode",
+        "cudagraph_mode_str,compilation_mode,lora_config",
         [
             # Test case 0: Full CG for mixed batches, no separate routine
-            (0, "FULL", CompilationMode.NONE),
+            ("FULL", CompilationMode.NONE, False),
             # Test case 1: Full CG for uniform batches, piecewise for mixed
-            (1, "FULL_AND_PIECEWISE", CompilationMode.NONE),
+            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
             # Test case 2: Full CG for uniform batches, no CG for mixed
-            (2, "FULL_DECODE_ONLY", CompilationMode.NONE),
+            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
             # Test case 3: PIECEWISE for all
-            (3, "PIECEWISE", CompilationMode.VLLM_COMPILE),
+            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
+            # Test case 4: PIECEWISE for all, specialize LoRA cases
+            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
         ],
     )
-    def test_dispatcher(self, cudagraph_mode_str, compilation_mode):
+    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
         # Setup dispatcher
         comp_config = CompilationConfig(
             cudagraph_mode=cudagraph_mode_str,
@@ -70,7 +75,17 @@ class TestCudagraphDispatcher:
             cudagraph_capture_sizes=[1, 8],
         )
 
-        config = _create_vllm_config(comp_config, max_num_seqs=8)
+        config = _create_vllm_config(
+            comp_config, max_num_seqs=8, lora_config=lora_config
+        )
+        if (
+            cudagraph_mode_str == "FULL_AND_PIECEWISE"
+            and compilation_mode == CompilationMode.NONE
+        ):
+            with pytest.raises(AssertionError):
+                dispatcher = CudagraphDispatcher(config)
+            return
+
         dispatcher = CudagraphDispatcher(config)
         dispatcher.initialize_cudagraph_keys(
             cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
@@ -78,17 +93,24 @@ class TestCudagraphDispatcher:
 
         # Verify the key is initialized correctly
         if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
-            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
+            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
+                4 if lora_config else 2
+            )
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
         if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
-            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
+            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
+                4 if lora_config else 2
+            )
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
 
         # Test dispatch logic
         # 1. non-uniform batch, size in cudagraph size list
-        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
+        desc_full_exact = BatchDescriptor(
+            num_tokens=8,
+            uniform_decode=False,
+        )
         rt_mode, key = dispatcher.dispatch(desc_full_exact)
         if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
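The expected key counts follow from the config above: with cudagraph_capture_sizes=[1, 8] the dispatcher holds one key per capture size, and if the lora_config=True case additionally specializes every size for with-LoRA and without-LoRA batches, the count doubles from 2 to 4. A back-of-the-envelope sketch of that counting (hypothetical tuple keys, not the real BatchDescriptor):

from itertools import product

capture_sizes = [1, 8]


def count_keys(specialize_lora: bool) -> int:
    """Count dispatch keys: one per capture size, times the LoRA axis
    when with-LoRA batches get their own captured graphs (assumed)."""
    lora_axis = [False, True] if specialize_lora else [None]
    keys = {(size, lora) for size, lora in product(capture_sizes, lora_axis)}
    return len(keys)


assert count_keys(specialize_lora=False) == 2
assert count_keys(specialize_lora=True) == 4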
@@ -138,7 +160,6 @@ class TestCUDAGraphWrapper:
         self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
         self.input_tensor = torch.randn(1, 10, device="cuda")
 
-    @create_new_process_for_each_test("spawn")
     def test_capture_and_replay(self):
         wrapper = CUDAGraphWrapper(
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
@@ -192,7 +213,6 @@ class TestCUDAGraphWrapper:
         eager_output = self.model(self.input_tensor)
         torch.testing.assert_close(eager_output, output2)
 
-    @create_new_process_for_each_test("spawn")
     def test_bypass_on_mode_mismatch(self):
         wrapper = CUDAGraphWrapper(
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
@@ -216,7 +236,6 @@ class TestCUDAGraphWrapper:
         mock_forward.assert_called_once()
         assert not wrapper.concrete_cudagraph_entries
 
-    @create_new_process_for_each_test("spawn")
     def test_bypass_on_mode_none(self):
         wrapper = CUDAGraphWrapper(
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
@@ -109,9 +109,9 @@ combo_cases_2 = [
 @pytest.mark.parametrize(
     "backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2
 )
-def test_cudagraph_compilation_combo(combo_case):
-    backend_name, cudagraph_mode, compilation_mode, supported = combo_case
-
+def test_cudagraph_compilation_combo(
+    backend_name, cudagraph_mode, compilation_mode, supported
+):
     env_vars = backend_configs[backend_name].env_vars
 
     with temporary_environ(env_vars), ExitStack() as stack:
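The last hunk fixes a signature mismatch rather than adding coverage: when @pytest.mark.parametrize receives a comma-separated string of four names, pytest injects four separate arguments, so the old single combo_case parameter could never be filled and the manual tuple unpacking was unreachable. A minimal reproduction of the two styles (generic example, not the vLLM test):

import pytest


@pytest.mark.parametrize("a,b", [(1, 2), (3, 4)])
def test_unpacked(a, b):
    # pytest splits the name string and injects one argument per name.
    assert a < b


# By contrast, a single tuple-valued parameter needs a single name:
@pytest.mark.parametrize("pair", [(1, 2), (3, 4)])
def test_packed(pair):
    a, b = pair
    assert a < b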