[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Luka Govedič
Date: 2025-10-17 10:10:23 -04:00
Committed by: GitHub
Parent: be429d0cfd
Commit: bd7157a071
28 changed files with 1519 additions and 721 deletions

View File

@ -416,8 +416,8 @@ steps:
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s compile/piecewise/
- label: PyTorch Fullgraph Test # 20min
timeout_in_minutes: 30
- label: PyTorch Fullgraph Test # 22min
timeout_in_minutes: 35
mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies:
@ -425,6 +425,7 @@ steps:
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
- pytest -v -s compile/test_fusions_e2e.py
- label: Kernels Core Operation Test # 48min
timeout_in_minutes: 75
@ -807,8 +808,8 @@ steps:
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
- label: Blackwell Test # 38 min
timeout_in_minutes: 60
- label: Blackwell Test # 21 min
timeout_in_minutes: 30
working_dir: "/vllm-workspace/"
gpu: b200
# optional: true
@ -821,8 +822,6 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
- vllm/compilation/fusion_attn.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
@ -839,15 +838,32 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- label: Blackwell Fusion Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
source_file_dependencies:
- csrc/quantization/fp4/
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/
# can affect pattern matching
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
commands:
- nvidia-smi
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_gpus=2 is not set
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60
@ -1100,7 +1116,7 @@ steps:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
##### H200 test #####
- label: Distrubted Tests (H200) # optional
- label: Distributed Tests (H200) # optional
gpu: h200
optional: true
working_dir: "/vllm-workspace/"
@ -1108,6 +1124,8 @@ steps:
commands:
- pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048

View File

@ -392,6 +392,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
TORCH_CHECK(input.scalar_type() == residual.scalar_type());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
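
For reference, a minimal PyTorch sketch of the contract these new checks enforce (input, residual, and weight sharing one dtype; residual and weight contiguous), written against the usual fused-add RMSNorm semantics rather than the CUDA kernel itself:

import torch

def fused_add_rms_norm_ref(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float):
    # mirrors the new TORCH_CHECKs: one shared dtype, contiguous residual/weight
    assert x.dtype == residual.dtype == weight.dtype
    assert residual.is_contiguous() and weight.is_contiguous()
    residual = residual + x  # the kernel updates residual in place with this sum
    var = residual.float().pow(2).mean(dim=-1, keepdim=True)
    out = (residual.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return out, residual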

View File

@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant(
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;

View File

@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant(
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
TORCH_CHECK(weight.dtype() == input.dtype());
TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
}
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
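
Similarly, a rough sketch of the dynamic per-token FP8 quantization path these checks guard (weight dtype matching input, float32 scales, optional residual matching input); the numerics are illustrative and not tied to the kernel's exact rounding:

import torch

def rms_norm_dynamic_per_token_quant_ref(x, weight, eps, fp8_max=448.0):
    assert weight.dtype == x.dtype
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    y = x.float() * torch.rsqrt(var + eps) * weight.float()
    # one dynamic scale per token (row), stored as float32 as the new check requires
    scales = y.abs().amax(dim=-1, keepdim=True) / fp8_max
    yq = (y / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return yq, scales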

View File

@ -3,16 +3,22 @@
import weakref
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from copy import deepcopy
import depyf
from torch import fx
from torch._ops import OpOverload
from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger("vllm.tests.compile.backend")
class LazyInitPass(InductorPass):
@ -45,20 +51,32 @@ class TestBackend:
def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config
self.inductor_config = compile_config.inductor_compile_config
vllm_config = get_current_vllm_config()
compile_config = vllm_config.compilation_config
# Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
self.inductor_config = deepcopy(compile_config.inductor_compile_config)
self.inductor_config["force_disable_caches"] = True
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
if debug_dump_path := vllm_config.compile_debug_dump_path():
logger.debug("Dumping depyf output to %s", debug_dump_path)
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
else:
self.debug_ctx = nullcontext()
def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
with self.debug_ctx:
return compile_fx(
graph, example_inputs, config_patches=self.inductor_config
)
@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
lazy_format_graph_code("graph_pre_pass", graph.owning_module)
VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes:
@ -68,6 +86,7 @@ class TestBackend:
VllmInductorPass.dump_prefix = None
self.graph_post_pass = deepcopy(graph)
lazy_format_graph_code("graph_post_pass", graph.owning_module)
# assign by reference, will reflect the final state of the graph
self.final_graph = graph
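
For context, the fusion tests below drive this backend roughly as follows (a sketch; the passes and the current VllmConfig are set up exactly as in those tests):

backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
compiled = torch.compile(model, backend=backend)
compiled(*model.example_inputs())
# graph_pre_pass / graph_post_pass / final_graph can now be inspected
backend.check_after_ops(model.ops_in_model_after())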

View File

@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import tempfile
from pathlib import Path
from typing import Any
import pytest
@ -10,8 +10,6 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
@ -22,23 +20,24 @@ from ..utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: list[str] | None = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{
"dtype": torch.float16,
},
),
(
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
{
"dtype": torch.float16,
},
{"dtype": torch.float16},
),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
("meta-llama/Llama-3.2-1B-Instruct", {}),
]
if all:
TEST_MODELS.extend(
[
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{"dtype": torch.float16},
),
]
)
# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(
@ -83,31 +82,38 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
"compilation_mode",
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
)
@pytest.mark.parametrize("model_info", models_list(all=True))
@pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
monkeypatch: pytest.MonkeyPatch,
model_info: tuple[str, dict[str, Any]],
model: str,
model_kwargs: dict[str, Any],
compilation_mode: int,
):
model, model_kwargs = model_info
if (
("w8a8" in model or "w8w8" in model)
and current_platform.has_device_capability((10, 0))
):
# int8 support was removed on Blackwell
pytest.skip("int8 support removed on Blackwell")
with monkeypatch.context():
print(f"MODEL={model}")
run_model(compilation_mode, model, model_kwargs)
run_model(compilation_mode, model, **model_kwargs)
# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
"compilation_config, model_info",
"compilation_config, model, model_kwargs",
[
# additional compile sizes, only some of the models
(
CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
model,
*model_info,
)
for model in models_list(all=False)
for model_info in models_list(all=False)
]
+ [
# RMSNorm + quant fusion, only 8-bit quant models
@ -117,18 +123,19 @@ def test_full_graph(
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
model,
*model_info,
)
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
]
+ [
# Test depyf integration works
(
CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
debug_dump_path=tempfile.gettempdir(),
debug_dump_path=Path(tempfile.gettempdir()),
),
("facebook/opt-125m", {}),
"facebook/opt-125m",
{},
),
]
+ [
@ -142,9 +149,9 @@ def test_full_graph(
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2],
),
model,
*model_info,
)
for model in models_list(all=False)
for model_info in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev")
],
)
@ -152,16 +159,24 @@ def test_full_graph(
@create_new_process_for_each_test()
def test_custom_compile_config(
compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]],
model: str,
model_kwargs: dict[str, Any],
):
if (
("w8a8" in model or "w8w8" in model)
and current_platform.has_device_capability((10, 0))
):
# int8 support was removed on Blackwell
pytest.skip("int8 support removed on Blackwell")
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"
):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model, model_kwargs = model_info
print(f"MODEL={model}")
run_model(compilation_config, model, model_kwargs)
run_model(compilation_config, model, **model_kwargs)
@pytest.mark.parametrize(
@ -176,50 +191,16 @@ def test_fp8_kv_scale_compile(compilation_mode: int):
"calculate_kv_scales": True,
"max_model_len": 512,
}
run_model(compilation_mode, model, model_kwargs)
run_model(compilation_mode, model, **model_kwargs)
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"],
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
compilation_config = (
compile_config
if isinstance(compile_config, CompilationConfig)
else CompilationConfig(level=compile_config)
)
model_kwargs = {
"kv_cache_dtype": "fp8",
"max_model_len": 1024,
}
with (
caplog_vllm.at_level(logging.DEBUG),
global_force_attn_backend_context_manager(_Backend.FLASHINFER),
):
run_model(compilation_config, model, model_kwargs)
try:
assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
caplog_vllm.text
)
except AssertionError:
# Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on
# PyTorch side, the compilation of a graph may be cached such
# that custom pass directly goes through cache. In this case,
# we go through this branch and assert that the pass is not
# triggered.
assert "Fused quantization" not in caplog_vllm.text
def run_model(
compile_config: int | CompilationConfig,
model: str,
model_kwargs: dict[str, Any],
):
prompts = [
"Hello, my name is",
"The president of the United States is",
@ -227,12 +208,17 @@ def run_model(
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
# Allow override from model_kwargs
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
# No cudagraphs by default
if compilation_config.cudagraph_mode is None:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
llm = LLM(
model=model,
enforce_eager=True,
tensor_parallel_size=1,
disable_custom_all_reduce=True,
compilation_config=compile_config,
compilation_config=compilation_config,
**model_kwargs,
)
outputs = llm.generate(prompts, sampling_params)

View File

@ -11,7 +11,13 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.config import (
CompilationConfig,
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@ -48,8 +54,7 @@ class TestSiluMul(torch.nn.Module):
return y
def example_inputs(self, num_tokens=32, hidden_size=128):
dtype = torch.float16 if TEST_FP8 else torch.float32
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
return (torch.rand(num_tokens, hidden_size * 2),)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
@ -67,15 +72,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
dtype = torch.float16 if TEST_FP8 else torch.float32
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size), dtype=dtype)
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter(
torch.ones(intermediate_size, dtype=dtype)
)
self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
torch.nn.init.normal_(self.gate_proj, std=0.02)
@ -112,9 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
return norm_output, residual_output
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
dtype = torch.float16 if TEST_FP8 else torch.float32
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
residual = torch.randn((batch_size * seq_len, hidden_size))
return (hidden_states, residual)
def ops_in_model(self, do_fusion):
@ -145,10 +145,9 @@ class TestRotaryEmbedding(torch.nn.Module):
return q_rotated, k_rotated
def example_inputs(self, num_tokens=32, head_dim=64):
dtype = torch.float16
positions = torch.arange(num_tokens, dtype=torch.long)
q = torch.randn(num_tokens, head_dim, dtype=dtype)
k = torch.randn(num_tokens, head_dim, dtype=dtype)
q = torch.randn(num_tokens, head_dim)
k = torch.randn(num_tokens, head_dim)
return (positions, q, k)
def ops_in_model(self, do_fusion):
@ -166,7 +165,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear(
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
self.hidden_size, self.hidden_size * 3, bias=False
)
self.rotary_emb = get_rope(
@ -190,10 +189,9 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
return qkv_updated
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
dtype = torch.float16
hidden_size = head_dim * num_heads
positions = torch.arange(num_tokens, dtype=torch.long)
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
hidden_states = torch.randn(num_tokens, hidden_size)
return (positions, hidden_states)
def ops_in_model(self, do_fusion):
@ -211,48 +209,58 @@ MODELS = [
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
def test_fix_functionalization(
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
custom_ops=["all"],
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
),
)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
with set_current_vllm_config(vllm_config):
assert RMSNorm.enabled()
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())

View File

@ -5,15 +5,18 @@ import pytest
import torch
import vllm.plugins
from vllm.compilation.fusion import (
FUSED_OPS,
QUANT_OPS,
FusedRMSQuantKey,
RMSNormQuantFusionPass,
)
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
from vllm.config import (
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@ -32,6 +35,9 @@ from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class TestModel(torch.nn.Module):
def __init__(
@ -45,18 +51,18 @@ class TestModel(torch.nn.Module):
):
super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else:
self.scale = [None for _ in range(2)]
self.scale = [None for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
for _ in range(3)
]
with override_cutlass_fp8_supported(not cuda_force_torch):
@ -65,8 +71,12 @@ class TestModel(torch.nn.Module):
act_quant_group_shape=group_shape,
)
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x):
resid = torch.sqrt(x)
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = self.norm[0](x)
x2 = self.fp8_linear.apply(
@ -78,24 +88,44 @@ class TestModel(torch.nn.Module):
x3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3
def ops_in_model_before(self):
return [QUANT_OPS[self.key]]
y3, resid = self.norm[2](x3, resid) # use resid here
x4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
FUSED_OPS[FusedRMSQuantKey(self.key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
@ -105,19 +135,32 @@ class TestModel(torch.nn.Module):
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
dtype,
hidden_size,
num_tokens,
eps,
static,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm", "+quant_fp8"],
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
)
),
)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
@ -126,31 +169,39 @@ def test_fusion_rmsnorm_quant(
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 2
# In pre-nodes, fp8 quant should be there and fused kernels should not
assert fusion_pass.matched_count == 3
backend.check_before_ops(model.ops_in_model_before())
# In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_before_ops(
model.ops_in_model_before_partial(), fully_replaced=False
)
backend.check_after_ops(model.ops_in_model_after())
# If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check that only 2 add nodes are left (from the final fused-add rmsnorm).
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
assert n_add_nodes(backend.graph_post_pass) == 2
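
For a worked count: the model quantizes after norm[0], norm[1], and norm[2] but not after norm[3], hence the three expected rmsnorm+quant matches asserted above.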

View File

@ -6,6 +6,7 @@ import pytest
import torch
import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
@ -17,6 +18,7 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
@ -25,8 +27,8 @@ from vllm.distributed.parallel_state import (
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
GroupShape,
QuantFP8,
)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
@ -40,13 +42,30 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm = self.norm(all_reduce)
return norm
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(x)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = torch.mm(y, self.w[0])
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = torch.mm(y2, self.w[1])
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid)
z4 = torch.mm(y3, self.w[2])
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid)
return y4
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
@ -55,44 +74,53 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm, _ = self.norm(all_reduce, residual)
return norm
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(
self.output, norm_output.contiguous(), self.scale
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
return self.output, residual_output
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@ -100,7 +128,9 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default,
torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled()
else torch.ops.aten.reciprocal.default,
]
@ -109,25 +139,48 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
torch.ops._C.scaled_fp4_quant(
self.output, norm_output, self.output_scale, self.scale
wq_gen, wscale_gen = zip(
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
)
return self.output, residual_output, self.output_scale
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
print(f"{self.wq=}, {self.wscale=}")
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
z2 = cutlass_scaled_fp4_mm(
yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
z3 = cutlass_scaled_fp4_mm(
yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
z4 = cutlass_scaled_fp4_mm(
yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@ -141,19 +194,19 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"test_model",
"test_model, enable_quant_fp8_custom_op",
[
TestAllReduceRMSNormModel,
TestAllReduceFusedAddRMSNormModel,
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
# TODO: Enable with torch==2.8.0
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
(TestAllReduceRMSNormModel, False),
(TestAllReduceRMSNormStaticQuantFP8Model, True),
(TestAllReduceRMSNormStaticQuantFP8Model, False),
(TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
@ -167,6 +220,8 @@ def test_all_reduce_fusion_pass_replace(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
num_processes = 2
if (
@ -181,7 +236,16 @@ def test_all_reduce_fusion_pass_replace(
def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(
fn,
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
args=(
num_processes,
test_model,
batch_size,
seq_len,
hidden_size,
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
),
nprocs=nprocs,
)
@ -196,6 +260,8 @@ def all_reduce_fusion_pass_on_test_model(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
current_platform.seed_everything(0)
@ -217,15 +283,22 @@ def all_reduce_fusion_pass_on_test_model(
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"]
mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
)
)
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
@ -233,24 +306,27 @@ def all_reduce_fusion_pass_on_test_model(
vllm_config.model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
with set_current_vllm_config(vllm_config):
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(
noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)
assert all_reduce_fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
assert all_reduce_fusion_pass.matched_count == 4, (
f"{all_reduce_fusion_pass.matched_count=}"
)
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
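
For a worked count: each test model above performs four tensor-parallel all-reduces, each feeding an RMSNorm (one plain, three fused-add), so four allreduce+rmsnorm(+quant) patterns are expected per graph.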

View File

@ -6,14 +6,15 @@ import pytest
import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
@ -28,21 +29,18 @@ from vllm.config import (
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
# globals needed for string-import custom Dynamo backend field
backend: TestBackend | None = None
backend_unfused: TestBackend | None = None
class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion."""
@ -104,6 +102,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_blocks = batch_size * max_blocks
backend = self.attn.backend
# TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN:
# k/v as 1st dimension
@ -241,26 +240,40 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
)
MODELS_FP8: list[tuple[str, type]] = []
MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[_Backend] = []
BACKENDS_FP4: list[_Backend] = []
if current_platform.is_cuda():
MODELS = [
HEADS = [(64, 8), (40, 8)]
MODELS_FP8 = [
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel,
),
)
]
MODELS_FP4 = [
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel,
),
)
]
HEADS = [(64, 8), (40, 8)]
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
BACKENDS_FP4 = [_Backend.FLASHINFER]
elif current_platform.is_rocm():
MODELS = [
HEADS = [(32, 8), (40, 8)]
MODELS_FP8 = [
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
HEADS = [(32, 8), (40, 8)]
else:
MODELS = []
HEADS = []
BACKENDS = [
_Backend.ROCM_AITER_UNIFIED_ATTN,
_Backend.ROCM_ATTN,
_Backend.TRITON_ATTN,
]
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@ -269,46 +282,36 @@ else:
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize(
"backend",
[_Backend.FLASHINFER]
if current_platform.is_cuda()
else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
"use_inductor_graph_partition",
[False] if current_platform.is_rocm() else [False, True],
"backend, model_name, model_class, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
# quant_fp4 only has the custom impl
+ list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)",
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
def test_attention_quant_pattern(
num_qo_heads: int,
num_kv_heads: int,
head_size: int,
batch_size: int,
dtype: torch.dtype,
custom_ops: str,
model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend,
use_inductor_graph_partition: bool,
dist_init,
caplog_vllm,
):
"""Test AttentionStaticQuantPattern fusion pass"""
if backend == _Backend.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0")
torch.manual_seed(42)
@ -322,8 +325,7 @@ def test_attention_quant_pattern(
scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+quant_fp8"],
use_inductor_graph_partition=use_inductor_graph_partition,
custom_ops=custom_ops_list,
),
cache_config=CacheConfig(cache_dtype="fp8"),
)
@ -358,8 +360,9 @@ def test_attention_quant_pattern(
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
# Run model directly without fusion
# Still compile so query QuantFP8 has closer numerics
result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
@ -414,16 +417,25 @@ def test_attention_quant_pattern(
)
# Check attn fusion support
quant_key = model_class.quant_key
quant_key: QuantKey = model_class.quant_key
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
# Note: fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
"All layers should support attention fusion"
)
# Check quantization ops in the graph before and after fusion
quant_op = (
torch.ops.aten.reciprocal
if "-quant_fp8" in custom_ops_list
else QUANT_OPS[quant_key]
)
# Note: for fp8, fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

View File

@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import itertools
import logging
from collections.abc import Iterable
from typing import Any, NamedTuple
import pytest
import regex as re
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer
from ..utils import flat_product, multi_gpu_test
class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: _Backend
attention_fusions: int
allreduce_fusions: int | None = None
MODELS_FP8: list[ModelBackendTestCase] = []
MODELS_FP4: list[ModelBackendTestCase] = []
MODELS: list[ModelBackendTestCase] = [] # tp-only
if current_platform.is_cuda():
MODELS_FP8 = [
ModelBackendTestCase(
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
allreduce_fusions=65,
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
),
]
MODELS_FP4 = [
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
),
]
# TP only
MODELS = [
ModelBackendTestCase(
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=65,
),
]
elif current_platform.is_rocm():
MODELS_FP8 = [
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN,
attention_fusions=32,
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32,
),
]
# TODO(luka) test both in nightly
CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"]
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl
+ list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if backend == _Backend.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
compilation_config = CompilationConfig(
# Testing properties
custom_ops=custom_ops_list,
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs)
matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 1, log_holder.text
assert int(matches[0]) == attention_fusions
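
For illustration, the debug log line this regex matches looks like the following (the source line number is hypothetical; the count is the parametrized attention_fusions value):

fusion_attn.py:451] Fused quant onto 32 attention nodes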
# TODO(luka) test both in nightly
CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"]
def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
for op_list in itertools.product(*custom_ops_lists):
yield ",".join(op_list)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
)
)
# Toggle RMSNorm for FP4 models and unquant models
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
not current_platform.is_cuda()
or not has_flashinfer()
or not current_platform.has_device_capability(90),
reason="allreduce+rmsnorm fusion requires flashinfer",
)
def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
compilation_config = CompilationConfig(
# Testing properties
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
custom_ops=custom_ops_list,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_fi_allreduce_fusion=True,
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert int(matches[0]) == attention_fusions
assert int(matches[1]) == attention_fusions
matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert int(matches[0]) == allreduce_fusions
assert int(matches[1]) == allreduce_fusions
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
compilation_config = (
compile_config
if isinstance(compile_config, CompilationConfig)
else CompilationConfig(level=compile_config)
)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
# Allow override from model_kwargs
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
# No cudagraphs by default
if compilation_config.cudagraph_mode is None:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
llm = LLM(
model=model,
compilation_config=compilation_config,
**model_kwargs,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@ -7,7 +7,7 @@ import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
# dummy custom pass that doesn't inherit
@ -42,7 +42,8 @@ class ProperPass(InductorPass):
],
)
def test_pass_manager_uuid(callable):
config = VllmConfig()
# Some passes need dtype to be set
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
pass_manager = PostGradPassManager()
pass_manager.configure(config)

View File

@ -18,6 +18,8 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
@ -42,9 +44,7 @@ prompts = [
class TestModel(torch.nn.Module):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
@ -95,13 +95,11 @@ class TestModel(torch.nn.Module):
class TestQuantModel(torch.nn.Module):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = vllm_config
self.vllm_config = get_current_vllm_config()
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)), requires_grad=False
)
@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
compilation_config = CompilationConfig(
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
)
) # NoOp needed for fusion
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
device_config = DeviceConfig(device=torch.device("cuda"))
# This is a fake model name used only to construct the model config
# in the vllm_config; it is not actually used.
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config.model_config = ModelConfig(
model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
vllm_config = VllmConfig(
model_config=model_config,
device_config=device_config,
compilation_config=compilation_config,
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
with set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
passes_for_backend: list[VllmInductorPass] = [
noop_pass,
sequence_parallelism_pass,
]
if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
passes_for_backend.append(cleanup_pass)
passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
model = test_model_cls(hidden_size, hidden_size * 2)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
assert sequence_parallelism_pass.matched_count == 1
assert sequence_parallelism_pass.matched_count == 1
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after())
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after())
# check if the functionalization pass is applied
for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
# check if the functionalization pass is applied
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())

View File

@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import contextlib
import pathlib
from copy import deepcopy
from tblib import pickling_support
# ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate
# failures from tests running in a subprocess.
# This should be run before any custom exception subclasses are defined.
@ -40,7 +43,7 @@ from transformers import (
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
from vllm import LLM, SamplingParams
from vllm import LLM, SamplingParams, envs
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
@ -1070,6 +1073,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
yield caplog
@pytest.fixture()
def caplog_mp_fork():
"""
This fixture enables capturing logs from a forked MP subprocess.
It should be used in conjunction with caplog_vllm.
By default, subprocess logs do not go through the parent process.
We instead create a queue listener in the parent process which
forwards logs to the logger's other handlers, and add a QueueHandler
to the root logger. Forked subprocesses will inherit the root logger
and pass their messages to the queue, which the listener will forward
to the root logger, which can be captured by caplog.
Note that this workaround only works for fork; with spawn, the subprocess
reinitializes logging and does not automatically inherit the queue.
We'd have to manually pass the queue to the subprocess at the spawn point.
See caplog_mp_spawn below.
"""
@contextlib.contextmanager
def ctx():
import logging.handlers
import multiprocessing as mp
logger_queue: mp.Queue[logging.LogRecord] = mp.Queue()
logger = logging.getLogger()
handlers = logger.handlers
# The listener works on a background thread, not inherited by the child.
queue_listener = logging.handlers.QueueListener(logger_queue, *handlers)
queue_listener.start()
# Add queue handler after creating the listener to avoid cycle
logger.addHandler(logging.handlers.QueueHandler(logger_queue))
yield
queue_listener.stop()
return ctx
class LogHolder:
def __init__(self):
self.text = None
@pytest.fixture()
def caplog_mp_spawn(tmp_path, monkeypatch):
"""
This fixture enables capturing logs from a spawned MP subprocess.
It does not require caplog_vllm (but it only contains logs from the child).
By default, subprocess logs do not go through the parent process.
We instead add a FileHandler to the config so the spawned child process
writes its logs to a temp file.
In the parent, we read the file and return the contents.
Note: this method could be extended to fork by either reconfiguring logging
in the parent or using a SocketHandler:
https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501
"""
@contextlib.contextmanager
def ctx(level: int | str):
from vllm.logger import DEFAULT_LOGGING_CONFIG
config_path = tmp_path / "vllm_logging_config.json"
log_path = tmp_path / "vllm.log"
log_holder = LogHolder()
config = deepcopy(DEFAULT_LOGGING_CONFIG)
if envs.VLLM_LOGGING_CONFIG_PATH:
path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH)
assert path.exists()
config = json.loads(path.read_text())
config["loggers"]["vllm"]["handlers"] += ["vllm_file"]
config["handlers"]["vllm_file"] = {
"class": "logging.FileHandler",
"formatter": "vllm",
"level": level,
"filename": log_path.as_posix(),
}
config_path.write_text(json.dumps(config))
with monkeypatch.context() as monkeypatch_ctx:
monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix())
monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1")
yield log_holder
log_holder.text = log_path.read_text()
return ctx
@pytest.fixture(scope="session")
def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context

View File

@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
.clamp(fp8_traits_min, fp8_traits_max)
.to(FP8_DTYPE)
)
return ref_out, ref_scale.view((1,))
return ref_out, ref_scale.view((1, 1))
def native_w8a8_block_matmul(

View File

@ -501,3 +501,49 @@ def test_streaming_complete_logs_full_text_content():
assert call_args[1] == "test-streaming-full-text"
assert call_args[2] == " (streaming complete)"
assert call_args[5] == "streaming_complete"
# Add vllm prefix to make sure logs go through the vllm logger
test_logger = init_logger("vllm.test_logger")
def mp_function(**kwargs):
# This function runs in a subprocess
test_logger.warning("This is a subprocess: %s", kwargs.get("a"))
test_logger.error("This is a subprocess error.")
test_logger.debug("This is a subprocess debug message: %s.", kwargs.get("b"))
def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork):
with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork():
import multiprocessing
ctx = multiprocessing.get_context("fork")
p = ctx.Process(
target=mp_function,
name=f"SubProcess{1}",
kwargs={"a": "AAAA", "b": "BBBBB"},
)
p.start()
p.join()
assert "AAAA" in caplog_vllm.text
assert "BBBBB" in caplog_vllm.text
def test_caplog_mp_spawn(caplog_mp_spawn):
with caplog_mp_spawn(logging.DEBUG) as log_holder:
import multiprocessing
ctx = multiprocessing.get_context("spawn")
p = ctx.Process(
target=mp_function,
name=f"SubProcess{1}",
kwargs={"a": "AAAA", "b": "BBBBB"},
)
p.start()
p.join()
assert "AAAA" in log_holder.text
assert "BBBBB" in log_holder.text

View File

@ -6,6 +6,7 @@ import contextlib
import copy
import functools
import importlib
import itertools
import json
import os
import random
@ -15,7 +16,7 @@ import sys
import tempfile
import time
import warnings
from collections.abc import Callable
from collections.abc import Callable, Iterable
from contextlib import ExitStack, contextmanager, suppress
from multiprocessing import Process
from pathlib import Path
@ -1261,3 +1262,23 @@ def check_answers(
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok >= accept_rate
def flat_product(*iterables: Iterable[Any]):
"""
Compute the cartesian product of the given iterables and flatten any
tuple elements in the results. Useful for avoiding nested tuples so that
test params can be unpacked directly from the decorator.
Example:
flat_product([(1, 2), (3, 4)], ["a", "b"]) ->
[
(1, 2, "a"),
(1, 2, "b"),
(3, 4, "a"),
(3, 4, "b"),
]
"""
for element in itertools.product(*iterables):
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
yield tuple(itertools.chain(*normalized))
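A hypothetical usage sketch (not part of this diff) of the intended pattern, assuming flat_product is importable as tests.utils.flat_product: (model, expected-count) tuples are combined with backend options in a single parametrize call so each test receives flat arguments.

import pytest

from tests.utils import flat_product  # assumed import path

# Each case receives (model, num_fusions, backend) as flat arguments rather
# than nested tuples; the model names and backends here are made up.
@pytest.mark.parametrize(
    "model, num_fusions, backend",
    flat_product([("model-a", 2), ("model-b", 4)], ["flash", "triton"]),
)
def test_example(model: str, num_fusions: int, backend: str):
    assert isinstance(num_fusions, int)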

View File

@ -40,7 +40,7 @@ from vllm.utils import (
unique_filepath,
)
from ..utils import create_new_process_for_each_test
from ..utils import create_new_process_for_each_test, flat_product
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
@ -771,3 +771,25 @@ def test_unique_filepath():
paths.add(path)
assert len(paths) == 10
assert len(list(Path(temp_dir).glob("*.txt"))) == 10
def test_flat_product():
# Check regular itertools.product behavior
result1 = list(flat_product([1, 2, 3], ["a", "b"]))
assert result1 == [
(1, "a"),
(1, "b"),
(2, "a"),
(2, "b"),
(3, "a"),
(3, "b"),
]
# check that the tuples get flattened
result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
assert result2 == [
(1, 2, "a", 5, 6),
(1, 2, "b", 5, 6),
(3, 4, "a", 5, 6),
(3, 4, "b", 5, 6),
]

View File

@ -1507,7 +1507,7 @@ def scaled_fp8_quant(
output, input, scale, scale_ub
)
else:
scale = torch.empty(1, device=input.device, dtype=torch.float32)
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"

View File

@ -17,10 +17,14 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
@ -41,11 +45,8 @@ else:
logger = init_logger(__name__)
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
class BasePattern:
@ -669,33 +670,24 @@ class AllReduceRMSNormPattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self):
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
input, weight = self.rmsnorm_matcher.inputs()
return [input, rms_result, weight]
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
):
def pattern(input: torch.Tensor, weight: torch.Tensor):
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized(
RMS_OP,
result=rms_result,
input=allreduce_output,
weight=weight,
epsilon=self.epsilon,
)
# rms_result, allreduce_output
return rms[1], allreduce_output
rms = self.rmsnorm_matcher(allreduce_output, weight)
def replacement(
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
):
return rms, allreduce_output
def replacement(input: torch.Tensor, weight: torch.Tensor):
residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@ -733,29 +725,19 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self):
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
input,
weight,
]
input, residual, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
# input, residual
return rms[1], rms[2]
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
return rms, residual
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
@ -779,6 +761,18 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
# Same pattern, but only return the output and not the residual
# (helpful at the end of the graph, where the residual is not used again).
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement(
first_return_only(pattern),
first_return_only(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
"""
@ -799,60 +793,37 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty(
[1, 8, 4], device=self.device, dtype=self.dtype
)
quant_result = torch.empty(
[1, 8, 4], device=self.device, dtype=self.quant_dtype
)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale]
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight, scale]
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
all_reduce = tensor_model_parallel_all_reduce(input)
rmsnorm_out_tuple = auto_functionalized(
RMS_OP,
result=rmsnorm_result,
input=all_reduce,
weight=weight,
epsilon=self.epsilon,
)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
quant_out_tuple = auto_functionalized(
STATIC_FP8_QUANT_OP,
result=quant_result,
input=rmsnorm_out_tuple[1],
scale=scale,
)
# quant_out, allreduce_output
return quant_out_tuple[1], all_reduce
def replacement(
input: torch.Tensor,
result_rms: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=result_rms,
quant_out=quant_result,
quant_out=result_quant,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
@ -892,64 +863,42 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
input, residual, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
quant_result = torch.empty(
[4, 4], device=self.device, dtype=self.quant_dtype
)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
quant_result,
residual,
input,
weight,
scale,
]
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight, scale]
def pattern(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant, _ = self.quant_matcher(rms, scale)
fused_add_rmsnorm_out_tuple = auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
quant_out_tuple = auto_functionalized(
STATIC_FP8_QUANT_OP,
result=quant_result,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale,
)
# quant_out, allreduce_output
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
return quant, res
def replacement(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=quant_result,
quant_out=result_quant,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
@ -986,14 +935,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty(
[1, 16, 16], device=self.device, dtype=self.dtype
)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
@ -1001,36 +947,21 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight = torch.empty([16], device=self.device, dtype=self.dtype)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [
input,
rmsnorm_result,
quant_result,
weight,
input_global_scale,
output_scale,
]
return [input, quant_result, weight, input_global_scale, output_scale]
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
):
all_reduce = tensor_model_parallel_all_reduce(input)
rmsnorm_out_tuple = auto_functionalized(
RMS_OP,
result=rmsnorm_result,
input=all_reduce,
weight=weight,
epsilon=self.epsilon,
)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=rmsnorm_out_tuple[1],
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
)
@ -1040,13 +971,13 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
def replacement(
input: torch.Tensor,
result_rms: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
):
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@ -1090,6 +1021,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
@ -1121,28 +1053,17 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
input_global_scale: torch.Tensor,
):
allreduce_output = tensor_model_parallel_all_reduce(input)
fused_add_rmsnorm_out_tuple = auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=fused_add_rmsnorm_out_tuple[1],
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
)
# quant_out, allreduce_output, output_scale
return (
quant_out_tuple[1],
fused_add_rmsnorm_out_tuple[2],
quant_out_tuple[2],
)
return quant_out_tuple[1], residual, quant_out_tuple[2]
def replacement(
quant_result: torch.Tensor,

View File

@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -92,13 +93,19 @@ class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
self.QUANT_OP = QUANT_OPS[key.quant]
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(key.quant)
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass):
# Cannot use methods, as the self argument affects tracing
def pattern(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale
)
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
result_rms = self.rmsnorm_matcher(input, weight)
return self.quant_matcher(result_rms, scale)[0]
# result
return at2[1]
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
def replacement(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
# input, weight
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale
)
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, _ = self.quant_matcher(result_rms, scale)
# result, residual
return at1[1], at[2]
return result, residual
def replacement(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
# input, weight, residual
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pm.register_replacement(
@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
)
def pattern(input: torch.Tensor, weight: torch.Tensor):
result_rms = self.rmsnorm_matcher(input, weight)
# result, scale
return at2[1], at2[2]
return self.quant_matcher(result_rms)
def replacement(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
def replacement(input: torch.Tensor, weight: torch.Tensor):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
# result, scale
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
)
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
# result, residual, scale
return at1[1], at[2], at1[2]
return result, residual, scale
def replacement(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
# result, residual, scale
return at[1], at[3], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
pass_name="rmsnorm_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log

View File

@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
@ -20,7 +22,9 @@ from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .fx_utils import is_func
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -66,9 +70,13 @@ class AttentionQuantPattern(ABC):
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(process_fx, trace_fn):
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@ -77,7 +85,20 @@ class AttentionQuantPattern(ABC):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule):
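# Permutes whose dims are already in order (identity permutes) can show up
# in traced pattern graphs and would otherwise prevent matching; drop them.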
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
dims = node.args[1]
if any(dim != i for i, dim in enumerate(dims)):
continue
# this is now an identity op, remove
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(self.quant_key):
@ -108,6 +129,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(
@ -115,7 +137,6 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
@ -131,17 +152,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
at2 = auto_functionalized(
self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale
)
return at2[1]
return self.quant_matcher(attn_out_view, scale)[0]
def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor,
):
# attn output in quant_dtype
@ -164,13 +182,10 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
inputs = [
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v
self.empty(
5, self.num_heads, self.head_size, dtype=self.dtype
), # attn_output
self.empty_quant(5, self.num_heads * self.head_size), # quant_output
self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size), # k
self.empty(5, self.num_heads, self.head_size), # v
self.empty(5, self.num_heads, self.head_size), # attn_output
empty_fp32(1, 1), # scale
]
@ -179,7 +194,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
@ -279,7 +296,9 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)

View File

@ -6,7 +6,7 @@ from collections.abc import Iterable, Iterator
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload
from torch._ops import OpOverload, OpOverloadPacket
def is_func(node: fx.Node, target) -> bool:
@ -64,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
# Also handles op overload packets and finds all overloads
def find_op_nodes(
op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
if isinstance(op, OpOverloadPacket):
for overload in op.overloads():
overload_op = getattr(op, overload)
yield from find_op_nodes(overload_op, graph)
return
assert isinstance(op, OpOverload)
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)

View File

@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
_normalize_quant_group_shape,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool):
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
self.enabled = enabled
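# Trace the custom (_C) op when the CustomOp is enabled at runtime, otherwise
# trace the native torch implementation, so the generated pattern matches
# whichever variant the compiled model actually contains.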
self.forward = self.forward_custom if enabled else self.forward_native
@abstractmethod
def forward_custom(self, *args, **kws):
pass
@abstractmethod
def forward_native(self, *args, **kws):
pass
def __call__(self, *args, **kws):
return self.forward(*args, **kws)
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
def empty_f32(self, *args, **kws):
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern"""
raise NotImplementedError
class MatcherRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
_, result = auto_functionalized(
RMS_OP,
result=result,
input=input,
weight=weight,
epsilon=self.epsilon,
)
return result
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight
)
class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
residual = self.empty(5, 16)
return [input, weight, residual]
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
_, result, residual = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
return result, residual
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
)
class MatcherQuantFP8(MatcherCustomOp):
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
assert quant_key.dtype == current_platform.fp8_dtype(), (
"MatcherQuantFP8 only supports the current platform's fp8 dtype"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
def forward_custom(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale
)
return result, scale
else:
assert scale is None
scale = self.make_scale(input)
_, result, scale = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
)
return result, scale
def forward_native(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
def make_scale(self, input: torch.Tensor):
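# The scale shape follows the quant group shape: e.g. (1, 1) for per-tensor
# scales and (num_tokens, 1) for per-token scales.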
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
scale_shape = (
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16)
if self.quant_key.scale.static:
return [input, self.empty_f32(1, 1)]
return [input]
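A minimal usage sketch (not part of this commit) showing how a fusion pattern can be written against these matchers, mirroring RMSNormStaticQuantPattern in fusion.py; it assumes a current vllm config is already set so MatcherCustomOp can pick up the model dtype and device.

import torch

from vllm.compilation.matcher_utils import MatcherQuantFP8, MatcherRMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)

# These matchers trace to torch.ops._C.rms_norm + static_scaled_fp8_quant when
# the custom ops are enabled, and to their pure-PyTorch equivalents otherwise,
# so a single registered pattern covers both configurations.
rmsnorm_matcher = MatcherRMSNorm(epsilon=1e-6)
quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
    rms = rmsnorm_matcher(input, weight)
    quant, _ = quant_matcher(rms, scale)
    return quant

# Example inputs for pattern tracing: input and weight from the RMSNorm
# matcher, plus the static scale from the quant matcher.
inputs = [*rmsnorm_matcher.inputs(), quant_matcher.inputs()[1]]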

View File

@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
import depyf
path.mkdir(parents=True, exist_ok=True)
logger.debug("Dumping depyf output to %s", path)
global context_manager
context_manager = depyf.prepare_debug(path.as_posix())
context_manager.__enter__()

View File

@ -5,7 +5,7 @@ import functools
from torch import fx as fx
from vllm import envs
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import set_env_var
@ -88,27 +88,30 @@ class PostGradPassManager(CustomGraphPass):
def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
# Set the current vllm config to allow tracing CustomOp instances
with set_current_vllm_config(config, check_compile=False):
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fusion:
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fusion:
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph

View File

@ -128,7 +128,8 @@ class VllmPatternMatcherPass(VllmInductorPass):
f" please add to dump_patterns if there are any errors.\n\n"
f"from torch._higher_order_ops.auto_functionalize import "
f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *",
f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm",
file=f,
)

View File

@ -178,14 +178,11 @@ class RMSNorm(CustomOp):
self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size
)
weight_dtype = dtype or torch.get_default_dtype()
self.has_weight = has_weight
if dtype is not None:
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size)
self.weight = torch.ones(hidden_size, dtype=weight_dtype)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
weight_dtype = self.weight.data.dtype
if current_platform.is_rocm():
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
@ -195,46 +192,68 @@ class RMSNorm(CustomOp):
with_fused_add=True, dtype=weight_dtype
)
@staticmethod
def forward_static(
x: torch.Tensor,
variance_epsilon: float,
hidden_size: int,
orig_dtype: torch.dtype,
weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
x = x.to(torch.float32)
if residual is not None:
# The residual is promoted f16->f32 automatically by the add; with an
# explicit cast, Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching).
x = x + residual
residual = x.to(orig_dtype)
if x.shape[-1] != hidden_size:
raise ValueError(
f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
)
if variance_size_override is None:
x_var = x
else:
if hidden_size < variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{variance_size_override}, but found: {hidden_size}"
)
x_var = x[:, :, :variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype)
if weight is not None:
x = x * weight
if residual is None:
return x
else:
return x, residual
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError(
"Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}"
)
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}"
)
x_var = x[:, :, : self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else:
return x, residual
return self.forward_static(
x,
self.variance_epsilon,
self.hidden_size,
x.dtype,
self.weight.data if self.has_weight else None,
residual,
self.variance_size_override,
)
def forward_cuda(
self,