[torch.compile] CUDAGraph Inductor partition integration (#24281)
Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@ -15,6 +15,7 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from ..silly_attention import get_global_counter, reset_global_counter
|
||||
@ -50,16 +51,21 @@ class SillyModel(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_inductor", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_simple_piecewise_compile(use_inductor):
|
||||
assert VLLM_USE_V1
|
||||
|
||||
def _run_simple_model(
|
||||
splitting_ops,
|
||||
use_inductor_graph_partition,
|
||||
use_inductor,
|
||||
expected_num_piecewise_graphs_seen,
|
||||
expected_num_piecewise_capturable_graphs_seen,
|
||||
expected_num_backend_compilations,
|
||||
expected_num_cudagraph_captured,
|
||||
):
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
splitting_ops=["silly.attention"],
|
||||
splitting_ops=splitting_ops,
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
cudagraph_copy_inputs=True,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
@ -70,11 +76,11 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
||||
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
||||
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=
|
||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
num_piecewise_capturable_graphs_seen=
|
||||
expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||
), set_forward_context(None,
|
||||
vllm_config=vllm_config): # background context
|
||||
# warm up with background context
|
||||
@ -104,3 +110,46 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
output = model(input)
|
||||
assert get_global_counter() == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_inductor", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_simple_piecewise_compile(use_inductor):
|
||||
assert VLLM_USE_V1
|
||||
_run_simple_model(
|
||||
splitting_ops=["silly.attention"],
|
||||
use_inductor_graph_partition=False,
|
||||
use_inductor=use_inductor,
|
||||
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
||||
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
||||
expected_num_backend_compilations=
|
||||
3, # num_piecewise_capturable_graphs_seen
|
||||
expected_num_cudagraph_captured=
|
||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
|
||||
def test_simple_inductor_graph_partition(splitting_ops):
|
||||
assert VLLM_USE_V1
|
||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
"in PyTorch 2.9+")
|
||||
|
||||
_run_simple_model(
|
||||
# inductor graph partition automatically resets splitting_ops
|
||||
# to be an empty list
|
||||
splitting_ops=splitting_ops,
|
||||
use_inductor_graph_partition=True,
|
||||
use_inductor=True,
|
||||
expected_num_piecewise_graphs_seen=
|
||||
1, # since not splitting at fx graph level
|
||||
expected_num_piecewise_capturable_graphs_seen=
|
||||
1, # since not splitting at fx graph level
|
||||
expected_num_backend_compilations=
|
||||
1, # since not splitting at fx graph level
|
||||
expected_num_cudagraph_captured=
|
||||
6, # inductor graph partition still captures 6
|
||||
# graph, same as fx graph partition.
|
||||
)
|
||||
|
||||
@ -60,4 +60,5 @@ direct_register_custom_op(
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
tags=(torch._C.Tag.cudagraph_unsafe, ),
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@ -10,9 +11,13 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from tests.v1.attention.utils import _Backend
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
PassConfig)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
@ -105,6 +110,18 @@ def test_full_graph(
|
||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
debug_dump_path=tempfile.gettempdir()),
|
||||
("facebook/opt-125m", {})),
|
||||
] + [
|
||||
# graph inductor partition
|
||||
(
|
||||
CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
# inductor graph partition uses
|
||||
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
|
||||
use_inductor_graph_partition=True,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
compile_sizes=[1, 2]),
|
||||
model) for model in models_list(all=False)
|
||||
if is_torch_equal_or_newer("2.9.0.dev")
|
||||
])
|
||||
# only test some of the models
|
||||
@create_new_process_for_each_test()
|
||||
@ -112,11 +129,51 @@ def test_custom_compile_config(
|
||||
compilation_config: CompilationConfig,
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
):
|
||||
if (compilation_config.use_inductor_graph_partition
|
||||
and not is_torch_equal_or_newer("2.9.0.dev")):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
"in PyTorch 2.9+")
|
||||
|
||||
model, model_kwargs = model_info
|
||||
print(f"MODEL={model}")
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
|
||||
|
||||
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
"in PyTorch 2.9+")
|
||||
|
||||
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_inductor_graph_partition=True,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
custom_ops=["+quant_fp8"],
|
||||
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
|
||||
)
|
||||
model_kwargs = {
|
||||
"kv_cache_dtype": "fp8",
|
||||
"max_model_len": 1024,
|
||||
}
|
||||
with caplog_vllm.at_level(
|
||||
logging.DEBUG), global_force_attn_backend_context_manager(
|
||||
_Backend.FLASHINFER):
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
|
||||
try:
|
||||
assert ("Fused quantization onto 48 attention nodes"
|
||||
in caplog_vllm.text), caplog_vllm.text
|
||||
except AssertionError:
|
||||
# Note: this message is only triggered when the compilation goes
|
||||
# through the custom pass. Due to multiple layers of cache on
|
||||
# PyTorch side, the compilation of a graph may be cached such
|
||||
# that custom pass directly goes through cache. In this case,
|
||||
# we go through this branch and assert that the pass is not
|
||||
# triggered.
|
||||
assert "Fused quantization" not in caplog_vllm.text
|
||||
|
||||
|
||||
def run_model(compile_config: Union[int, CompilationConfig], model: str,
|
||||
model_kwargs: dict[str, Any]):
|
||||
prompts = [
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -339,6 +340,10 @@ else:
|
||||
@pytest.mark.parametrize(
|
||||
"split_attention",
|
||||
[False, True] if current_platform.is_rocm() else [False])
|
||||
# TODO(boyuan): test inductor graph partition on rocm
|
||||
@pytest.mark.parametrize(
|
||||
"use_inductor_graph_partition",
|
||||
[False] if current_platform.is_rocm() else [False, True])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test ROCm or CUDA")
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@ -352,9 +357,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
dtype: torch.dtype, model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend, split_attention: bool,
|
||||
monkeypatch, dist_init):
|
||||
use_inductor_graph_partition: bool,
|
||||
monkeypatch, dist_init, caplog_vllm):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||
"2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
"in PyTorch 2.9+")
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
if split_attention:
|
||||
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
|
||||
@ -372,6 +383,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+quant_fp8"],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
),
|
||||
cache_config=CacheConfig(cache_dtype="fp8"))
|
||||
|
||||
@ -444,6 +456,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
backend=test_backend,
|
||||
fullgraph=True)
|
||||
assert model_compiled.attn._o_scale_float is None
|
||||
|
||||
result_fused_1 = model_compiled(q, k, v)
|
||||
|
||||
if backend == _Backend.FLASHINFER:
|
||||
@ -453,6 +466,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
# _o_scale_float
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
result_fused_2 = model_compiled(q, k, v)
|
||||
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
|
||||
torch.testing.assert_close(result_unfused,
|
||||
|
||||
Reference in New Issue
Block a user