Add ability to use CUDAGraphs with use_inductor=False (#17345)
Signed-off-by: rzou <zou3519@gmail.com>
This commit is contained in:
@ -74,11 +74,12 @@ class SillyModel(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def test_simple_piecewise_compile():
|
||||
def _test_simple_piecewise_compile(*, use_inductor):
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_copy_inputs=True,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
@ -108,3 +109,11 @@ def test_simple_piecewise_compile():
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
||||
|
||||
def test_simple_piecewise_compile_inductor():
|
||||
_test_simple_piecewise_compile(use_inductor=True)
|
||||
|
||||
|
||||
def test_simple_piecewise_compile_no_inductor():
|
||||
_test_simple_piecewise_compile(use_inductor=False)
|
||||
|
||||
@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor,
|
||||
@torch.inference_mode
|
||||
def run_model(llama_config,
|
||||
use_compile: bool,
|
||||
use_inductor: bool,
|
||||
split_attn: bool = False) -> torch.Tensor:
|
||||
|
||||
if use_compile:
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
if split_attn:
|
||||
@ -304,7 +306,7 @@ def run_model(llama_config,
|
||||
return output.cpu()
|
||||
|
||||
|
||||
def test_toy_llama():
|
||||
def _test_toy_llama(*, use_inductor):
|
||||
# compare output with and without piecewise compilation
|
||||
|
||||
llama_config = LlamaConfig(hidden_size=128,
|
||||
@ -326,8 +328,14 @@ def test_toy_llama():
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_caputured=0,
|
||||
):
|
||||
outputs.append(run_model(llama_config, use_compile=False))
|
||||
run_model(tractable_config, use_compile=False)
|
||||
outputs.append(
|
||||
run_model(llama_config, use_inductor=False, use_compile=False))
|
||||
run_model(tractable_config, use_inductor=False, use_compile=False)
|
||||
|
||||
if use_inductor:
|
||||
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
|
||||
else:
|
||||
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
@ -336,9 +344,13 @@ def test_toy_llama():
|
||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_caputured=
|
||||
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
**kwargs,
|
||||
):
|
||||
outputs.append(run_model(llama_config, use_compile=True))
|
||||
run_model(tractable_config, use_compile=True)
|
||||
outputs.append(
|
||||
run_model(llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True))
|
||||
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
@ -353,13 +365,27 @@ def test_toy_llama():
|
||||
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
outputs.append(
|
||||
run_model(llama_config, use_compile=True, split_attn=True))
|
||||
run_model(tractable_config, use_compile=True, split_attn=True)
|
||||
run_model(llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True))
|
||||
run_model(tractable_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True)
|
||||
|
||||
for i in range(1, len(outputs)):
|
||||
assert torch.allclose(outputs[0], outputs[i])
|
||||
|
||||
|
||||
def test_toy_llama_inductor():
|
||||
_test_toy_llama(use_inductor=True)
|
||||
|
||||
|
||||
def test_toy_no_inductor():
|
||||
_test_toy_llama(use_inductor=False)
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def benchmark():
|
||||
from triton.testing import do_bench
|
||||
|
||||
Reference in New Issue
Block a user