[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
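For orientation, a minimal sketch (not part of the diff) of the configuration this change targets: running the attention-quant and allreduce fusion passes while the rms_norm/quant_fp8 custom ops stay disabled. The option names are taken from the config objects exercised in the tests below; the model name and exact flag values are illustrative assumptions, not the commit's own settings.

    # Hypothetical usage sketch, assuming the PassConfig/CompilationConfig fields shown in the diff.
    from vllm import LLM
    from vllm.config import CompilationConfig, CompilationMode, PassConfig

    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        # custom ops disabled: fusion now matches the native torch implementations
        custom_ops=["-rms_norm", "-quant_fp8"],
        pass_config=PassConfig(
            enable_attn_fusion=True,
            enable_fi_allreduce_fusion=True,
            enable_noop=True,
        ),
    )
    llm = LLM(
        model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",  # placeholder model from the e2e tests
        compilation_config=compilation_config,
    )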
@ -416,8 +416,8 @@ steps:
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s compile/piecewise/

- label: PyTorch Fullgraph Test # 20min
timeout_in_minutes: 30
- label: PyTorch Fullgraph Test # 22min
timeout_in_minutes: 35
mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies:
@ -425,6 +425,7 @@ steps:
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
- pytest -v -s compile/test_fusions_e2e.py

- label: Kernels Core Operation Test # 48min
timeout_in_minutes: 75
@ -807,8 +808,8 @@ steps:
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper

- label: Blackwell Test # 38 min
timeout_in_minutes: 60
- label: Blackwell Test # 21 min
timeout_in_minutes: 30
working_dir: "/vllm-workspace/"
gpu: b200
# optional: true
@ -821,8 +822,6 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
- vllm/compilation/fusion_attn.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
@ -839,15 +838,32 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py

- label: Blackwell Fusion Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
source_file_dependencies:
- csrc/quantization/fp4/
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/
# can affect pattern matching
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
commands:
- nvidia-smi
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_gpus=2 is not set
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py

- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60
@ -1100,7 +1116,7 @@ steps:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4

##### H200 test #####
- label: Distrubted Tests (H200) # optional
- label: Distributed Tests (H200) # optional
gpu: h200
optional: true
working_dir: "/vllm-workspace/"
@ -1108,6 +1124,8 @@ steps:
commands:
- pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048

@ -392,6 +392,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
TORCH_CHECK(input.scalar_type() == residual.scalar_type());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);

@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant(
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;

@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant(
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
TORCH_CHECK(weight.dtype() == input.dtype());
TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
}

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {

@ -3,16 +3,22 @@

import weakref
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from copy import deepcopy

import depyf
from torch import fx
from torch._ops import OpOverload
from torch.fx._utils import lazy_format_graph_code

from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger

logger = init_logger("vllm.tests.compile.backend")


class LazyInitPass(InductorPass):
@ -45,20 +51,32 @@ class TestBackend:

def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config
self.inductor_config = compile_config.inductor_compile_config
vllm_config = get_current_vllm_config()
compile_config = vllm_config.compilation_config
# Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
self.inductor_config = deepcopy(compile_config.inductor_compile_config)
self.inductor_config["force_disable_caches"] = True
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass

if debug_dump_path := vllm_config.compile_debug_dump_path():
logger.debug("Dumping depyf output to %s", debug_dump_path)
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
else:
self.debug_ctx = nullcontext()

def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx

return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
with self.debug_ctx:
return compile_fx(
graph, example_inputs, config_patches=self.inductor_config
)

@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
lazy_format_graph_code("graph_pre_pass", graph.owning_module)

VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes:
@ -68,6 +86,7 @@ class TestBackend:
VllmInductorPass.dump_prefix = None

self.graph_post_pass = deepcopy(graph)
lazy_format_graph_code("graph_post_pass", graph.owning_module)
# assign by reference, will reflect the final state of the graph
self.final_graph = graph

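A minimal usage sketch of TestBackend (assumed, mirroring the updated tests below): passes are built under the active VllmConfig and the backend instance is handed to torch.compile, after which its captured FX graphs can be inspected. The toy model and sizes are placeholders, and a GPU build of vLLM is assumed.

    # Sketch only: class and config names come from the files in this commit.
    import torch
    from torch import nn

    from tests.compile.backend import TestBackend
    from vllm.compilation.noop_elimination import NoOpEliminationPass
    from vllm.config import CompilationConfig, PassConfig, VllmConfig, set_current_vllm_config

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(pass_config=PassConfig(enable_noop=True))
    )
    with set_current_vllm_config(vllm_config):
        backend = TestBackend(NoOpEliminationPass(vllm_config))  # passes under test
        model = nn.Linear(16, 16, device="cuda", dtype=torch.float16)  # stand-in model
        x = torch.randn(4, 16, device="cuda", dtype=torch.float16)
        torch.compile(model, backend=backend)(x)
        # backend.graph_pre_pass / backend.graph_post_pass hold the FX graphs
        # captured before and after the custom passes ran.
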
@ -1,8 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@ -10,8 +10,6 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
@ -22,23 +20,24 @@ from ..utils import create_new_process_for_each_test
|
||||
def models_list(*, all: bool = True, keywords: list[str] | None = None):
|
||||
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
||||
("facebook/opt-125m", {}),
|
||||
(
|
||||
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
|
||||
{
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
),
|
||||
(
|
||||
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
|
||||
{
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
{"dtype": torch.float16},
|
||||
),
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||
]
|
||||
|
||||
if all:
|
||||
TEST_MODELS.extend(
|
||||
[
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||
(
|
||||
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
|
||||
{"dtype": torch.float16},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# TODO: figure out why this fails.
|
||||
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
||||
TEST_MODELS.append(
|
||||
@ -83,31 +82,38 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
|
||||
"compilation_mode",
|
||||
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
|
||||
)
|
||||
@pytest.mark.parametrize("model_info", models_list(all=True))
|
||||
@pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
|
||||
@create_new_process_for_each_test()
|
||||
def test_full_graph(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
model: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
compilation_mode: int,
|
||||
):
|
||||
model, model_kwargs = model_info
|
||||
if (
|
||||
"w8a8" in model
|
||||
or "w8w8" in model
|
||||
and current_platform.has_device_capability((10, 0))
|
||||
):
|
||||
# int8 removed on Blackwell:
|
||||
pytest.skip("int8 support removed on Blackwell")
|
||||
|
||||
with monkeypatch.context():
|
||||
print(f"MODEL={model}")
|
||||
|
||||
run_model(compilation_mode, model, model_kwargs)
|
||||
run_model(compilation_mode, model, **model_kwargs)
|
||||
|
||||
|
||||
# TODO(luka) add other supported compilation config scenarios here
|
||||
@pytest.mark.parametrize(
|
||||
"compilation_config, model_info",
|
||||
"compilation_config, model, model_kwargs",
|
||||
[
|
||||
# additional compile sizes, only some of the models
|
||||
(
|
||||
CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
|
||||
model,
|
||||
*model_info,
|
||||
)
|
||||
for model in models_list(all=False)
|
||||
for model_info in models_list(all=False)
|
||||
]
|
||||
+ [
|
||||
# RMSNorm + quant fusion, only 8-bit quant models
|
||||
@ -117,18 +123,19 @@ def test_full_graph(
|
||||
custom_ops=["+rms_norm"],
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
),
|
||||
model,
|
||||
*model_info,
|
||||
)
|
||||
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||
for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||
]
|
||||
+ [
|
||||
# Test depyf integration works
|
||||
(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
debug_dump_path=tempfile.gettempdir(),
|
||||
debug_dump_path=Path(tempfile.gettempdir()),
|
||||
),
|
||||
("facebook/opt-125m", {}),
|
||||
"facebook/opt-125m",
|
||||
{},
|
||||
),
|
||||
]
|
||||
+ [
|
||||
@ -142,9 +149,9 @@ def test_full_graph(
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
compile_sizes=[1, 2],
|
||||
),
|
||||
model,
|
||||
*model_info,
|
||||
)
|
||||
for model in models_list(all=False)
|
||||
for model_info in models_list(all=False)
|
||||
if is_torch_equal_or_newer("2.9.0.dev")
|
||||
],
|
||||
)
|
||||
@ -152,16 +159,24 @@ def test_full_graph(
|
||||
@create_new_process_for_each_test()
|
||||
def test_custom_compile_config(
|
||||
compilation_config: CompilationConfig,
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
model: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
):
|
||||
if (
|
||||
"w8a8" in model
|
||||
or "w8w8" in model
|
||||
and current_platform.has_device_capability((10, 0))
|
||||
):
|
||||
# int8 removed on Blackwell:
|
||||
pytest.skip("int8 support removed on Blackwell")
|
||||
|
||||
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||
"2.9.0.dev"
|
||||
):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
model, model_kwargs = model_info
|
||||
print(f"MODEL={model}")
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
run_model(compilation_config, model, **model_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -176,50 +191,16 @@ def test_fp8_kv_scale_compile(compilation_mode: int):
|
||||
"calculate_kv_scales": True,
|
||||
"max_model_len": 512,
|
||||
}
|
||||
run_model(compilation_mode, model, model_kwargs)
|
||||
run_model(compilation_mode, model, **model_kwargs)
|
||||
|
||||
|
||||
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
compilation_config = CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
custom_ops=["+quant_fp8"],
|
||||
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
|
||||
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
|
||||
compilation_config = (
|
||||
compile_config
|
||||
if isinstance(compile_config, CompilationConfig)
|
||||
else CompilationConfig(level=compile_config)
|
||||
)
|
||||
model_kwargs = {
|
||||
"kv_cache_dtype": "fp8",
|
||||
"max_model_len": 1024,
|
||||
}
|
||||
with (
|
||||
caplog_vllm.at_level(logging.DEBUG),
|
||||
global_force_attn_backend_context_manager(_Backend.FLASHINFER),
|
||||
):
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
|
||||
try:
|
||||
assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
|
||||
caplog_vllm.text
|
||||
)
|
||||
except AssertionError:
|
||||
# Note: this message is only triggered when the compilation goes
|
||||
# through the custom pass. Due to multiple layers of cache on
|
||||
# PyTorch side, the compilation of a graph may be cached such
|
||||
# that custom pass directly goes through cache. In this case,
|
||||
# we go through this branch and assert that the pass is not
|
||||
# triggered.
|
||||
assert "Fused quantization" not in caplog_vllm.text
|
||||
|
||||
|
||||
def run_model(
|
||||
compile_config: int | CompilationConfig,
|
||||
model: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
):
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -227,12 +208,17 @@ def run_model(
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
# Allow override from model_kwargs
|
||||
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
|
||||
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
|
||||
|
||||
# No cudagraphs by default
|
||||
if compilation_config.cudagraph_mode is None:
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
llm = LLM(
|
||||
model=model,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1,
|
||||
disable_custom_all_reduce=True,
|
||||
compilation_config=compile_config,
|
||||
compilation_config=compilation_config,
|
||||
**model_kwargs,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
@ -11,7 +11,13 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
@ -48,8 +54,7 @@ class TestSiluMul(torch.nn.Module):
|
||||
return y
|
||||
|
||||
def example_inputs(self, num_tokens=32, hidden_size=128):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
|
||||
return (torch.rand(num_tokens, hidden_size * 2),)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
if TEST_FP8 and do_fusion:
|
||||
@ -67,15 +72,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size), dtype=dtype)
|
||||
torch.empty((intermediate_size, hidden_size))
|
||||
)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
self.norm.weight = torch.nn.Parameter(
|
||||
torch.ones(intermediate_size, dtype=dtype)
|
||||
)
|
||||
self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
|
||||
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
@ -112,9 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
return norm_output, residual_output
|
||||
|
||||
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size))
|
||||
return (hidden_states, residual)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
@ -145,10 +145,9 @@ class TestRotaryEmbedding(torch.nn.Module):
|
||||
return q_rotated, k_rotated
|
||||
|
||||
def example_inputs(self, num_tokens=32, head_dim=64):
|
||||
dtype = torch.float16
|
||||
positions = torch.arange(num_tokens, dtype=torch.long)
|
||||
q = torch.randn(num_tokens, head_dim, dtype=dtype)
|
||||
k = torch.randn(num_tokens, head_dim, dtype=dtype)
|
||||
q = torch.randn(num_tokens, head_dim)
|
||||
k = torch.randn(num_tokens, head_dim)
|
||||
return (positions, q, k)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
@ -166,7 +165,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
self.hidden_size = head_dim * num_heads
|
||||
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
|
||||
self.hidden_size, self.hidden_size * 3, bias=False
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -190,10 +189,9 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
return qkv_updated
|
||||
|
||||
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
|
||||
dtype = torch.float16
|
||||
hidden_size = head_dim * num_heads
|
||||
positions = torch.arange(num_tokens, dtype=torch.long)
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
hidden_states = torch.randn(num_tokens, hidden_size)
|
||||
return (positions, hidden_states)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
@ -211,48 +209,58 @@ MODELS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("model_class", MODELS)
|
||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||
def test_fix_functionalization(
|
||||
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
custom_ops=["all"],
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
|
||||
),
|
||||
)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
passes = (
|
||||
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion
|
||||
else [noop_pass, cleanup_pass]
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
assert RMSNorm.enabled()
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
passes = (
|
||||
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion
|
||||
else [noop_pass, cleanup_pass]
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
model = model_class()
|
||||
torch.compile(model, backend=backend_func)(*model.example_inputs())
|
||||
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
model = model_class()
|
||||
torch.compile(model, backend=backend_func)(*model.example_inputs())
|
||||
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
for op in model.ops_not_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
for op in model.ops_not_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||
|
||||
@ -5,15 +5,18 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.plugins
|
||||
from vllm.compilation.fusion import (
|
||||
FUSED_OPS,
|
||||
QUANT_OPS,
|
||||
FusedRMSQuantKey,
|
||||
RMSNormQuantFusionPass,
|
||||
)
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
@ -32,6 +35,9 @@ from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
@ -45,18 +51,18 @@ class TestModel(torch.nn.Module):
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cuda_force_torch = cuda_force_torch
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
else:
|
||||
self.scale = [None for _ in range(2)]
|
||||
self.scale = [None for _ in range(3)]
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
for _ in range(2)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
@ -65,8 +71,12 @@ class TestModel(torch.nn.Module):
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
|
||||
def forward(self, x):
|
||||
resid = torch.sqrt(x)
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
x = resid = torch.relu(x)
|
||||
y = self.norm[0](x)
|
||||
|
||||
x2 = self.fp8_linear.apply(
|
||||
@ -78,24 +88,44 @@ class TestModel(torch.nn.Module):
|
||||
x3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
return y3
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [QUANT_OPS[self.key]]
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
x4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
|
||||
]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return (
|
||||
[QUANT_OPS[self.quant_key]]
|
||||
if self.enable_quant_fp8_custom_op
|
||||
else [torch.ops.aten.reciprocal]
|
||||
)
|
||||
|
||||
def ops_in_model_before_partial(self):
|
||||
return (
|
||||
[RMS_OP, RMS_ADD_OP]
|
||||
if self.enable_rms_norm_custom_op
|
||||
else [torch.ops.aten.rsqrt]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("static", [True, False])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize(
|
||||
@ -105,19 +135,32 @@ class TestModel(torch.nn.Module):
|
||||
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
def test_fusion_rmsnorm_quant(
|
||||
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
|
||||
dtype,
|
||||
hidden_size,
|
||||
num_tokens,
|
||||
eps,
|
||||
static,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
cuda_force_torch,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
custom_ops = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_quant_fp8_custom_op:
|
||||
custom_ops.append("+quant_fp8")
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
)
|
||||
),
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
@ -126,31 +169,39 @@ def test_fusion_rmsnorm_quant(
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||
model = TestModel(hidden_size, eps, static, cuda_force_torch)
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result = model(x)
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
result_fused = model_fused(x)
|
||||
|
||||
model2 = torch.compile(model, backend=backend)
|
||||
result2 = model2(x)
|
||||
model_unfused = torch.compile(model, backend=backend2)
|
||||
result_unfused = model_unfused(x)
|
||||
|
||||
# Higher tol for dynamic, even higher for bfloat16
|
||||
if static:
|
||||
ATOL, RTOL = (1e-3, 1e-3)
|
||||
elif dtype == torch.float16:
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert fusion_pass.matched_count == 2
|
||||
|
||||
# In pre-nodes, fp8 quant should be there and fused kernels should not
|
||||
assert fusion_pass.matched_count == 3
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, fused kernels should be there and fp8 quant should not
|
||||
backend.check_before_ops(
|
||||
model.ops_in_model_before_partial(), fully_replaced=False
|
||||
)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
# If RMSNorm custom op is disabled (native/torch impl used),
|
||||
# there's a risk that the fused add doesn't get included in the
|
||||
# replacement and only the rms part gets fused with quant.
|
||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||
if not enable_rms_norm_custom_op:
|
||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||
assert n_add_nodes(backend.graph_post_pass) == 2
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
@ -17,6 +18,7 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
@ -25,8 +27,8 @@ from vllm.distributed.parallel_state import (
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
GroupShape,
|
||||
QuantFP8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
@ -40,13 +42,30 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm = self.norm(all_reduce)
|
||||
return norm
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(x)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = torch.mm(y, self.w[0])
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = torch.mm(y2, self.w[1])
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
|
||||
y3, resid = self.norm[2](x3, resid)
|
||||
|
||||
z4 = torch.mm(y3, self.w[2])
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
|
||||
y4, resid = self.norm[3](x4, resid)
|
||||
return y4
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
@ -55,44 +74,53 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size)
|
||||
.to(dtype=current_platform.fp8_dtype())
|
||||
.t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm, _ = self.norm(all_reduce, residual)
|
||||
return norm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
torch.ops._C.static_scaled_fp8_quant(
|
||||
self.output, norm_output.contiguous(), self.scale
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
return self.output, residual_output
|
||||
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
@ -100,7 +128,9 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default
|
||||
if self.fp8_linear.quant_fp8.enabled()
|
||||
else torch.ops.aten.reciprocal.default,
|
||||
]
|
||||
|
||||
|
||||
@ -109,25 +139,48 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(token_num, 128)
|
||||
scale_n = hidden_size // 16
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
|
||||
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
|
||||
self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
self.output, norm_output, self.output_scale, self.scale
|
||||
wq_gen, wscale_gen = zip(
|
||||
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
|
||||
)
|
||||
return self.output, residual_output, self.output_scale
|
||||
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
|
||||
print(f"{self.wq=}, {self.wscale=}")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
|
||||
z2 = cutlass_scaled_fp4_mm(
|
||||
yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
|
||||
)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
|
||||
z3 = cutlass_scaled_fp4_mm(
|
||||
yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
|
||||
)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
|
||||
z4 = cutlass_scaled_fp4_mm(
|
||||
yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
|
||||
)
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
@ -141,19 +194,19 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"test_model",
|
||||
"test_model, enable_quant_fp8_custom_op",
|
||||
[
|
||||
TestAllReduceRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||
# TODO: Enable with torch==2.8.0
|
||||
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||
(TestAllReduceRMSNormModel, False),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, True),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, False),
|
||||
(TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(
|
||||
not find_spec("flashinfer")
|
||||
@ -167,6 +220,8 @@ def test_all_reduce_fusion_pass_replace(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
):
|
||||
num_processes = 2
|
||||
if (
|
||||
@ -181,7 +236,16 @@ def test_all_reduce_fusion_pass_replace(
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||
args=(
|
||||
num_processes,
|
||||
test_model,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
@ -196,6 +260,8 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
@ -217,15 +283,22 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
custom_ops = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_quant_fp8_custom_op:
|
||||
custom_ops.append("+quant_fp8")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"]
|
||||
mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
|
||||
)
|
||||
)
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_fi_allreduce_fusion=True, enable_noop=True
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
@ -233,24 +306,27 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
backend = TestBackend(
|
||||
noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
|
||||
)
|
||||
|
||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
|
||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
residual = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states, residual)
|
||||
|
||||
assert all_reduce_fusion_pass.matched_count == 1
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
del all_reduce_fusion_pass
|
||||
assert all_reduce_fusion_pass.matched_count == 4, (
|
||||
f"{all_reduce_fusion_pass.matched_count=}"
|
||||
)
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
del all_reduce_fusion_pass
|
||||
|
||||
@ -6,14 +6,15 @@ import pytest
|
||||
import torch._dynamo
|
||||
|
||||
from tests.compile.backend import LazyInitPass, TestBackend
|
||||
from tests.utils import flat_product
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
@ -28,21 +29,18 @@ from vllm.config import (
|
||||
)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
# globals needed for string-import custom Dynamo backend field
|
||||
backend: TestBackend | None = None
|
||||
backend_unfused: TestBackend | None = None
|
||||
|
||||
|
||||
class AttentionQuantPatternModel(torch.nn.Module):
|
||||
"""Base model for AttentionQuantPattern fusion."""
|
||||
@ -104,6 +102,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
num_blocks = batch_size * max_blocks
|
||||
backend = self.attn.backend
|
||||
|
||||
# TODO(luka) use get_kv_cache_stride_order
|
||||
# Create dummy KV cache for the selected backend
|
||||
if backend == _Backend.ROCM_ATTN:
|
||||
# k/v as 1st dimension
|
||||
@ -241,26 +240,40 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
)
|
||||
|
||||
|
||||
MODELS_FP8: list[tuple[str, type]] = []
|
||||
MODELS_FP4: list[tuple[str, type]] = []
|
||||
HEADS: list[tuple[int, int]] = []
|
||||
SPLIT_ATTENTION: list[bool] = []
|
||||
BACKENDS_FP8: list[_Backend] = []
|
||||
BACKENDS_FP4: list[_Backend] = []
|
||||
|
||||
if current_platform.is_cuda():
|
||||
MODELS = [
|
||||
HEADS = [(64, 8), (40, 8)]
|
||||
MODELS_FP8 = [
|
||||
(
|
||||
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
TestAttentionFp8StaticQuantPatternModel,
|
||||
),
|
||||
)
|
||||
]
|
||||
MODELS_FP4 = [
|
||||
(
|
||||
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
TestAttentionNvfp4QuantPatternModel,
|
||||
),
|
||||
)
|
||||
]
|
||||
HEADS = [(64, 8), (40, 8)]
|
||||
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
|
||||
BACKENDS_FP4 = [_Backend.FLASHINFER]
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
MODELS = [
|
||||
HEADS = [(32, 8), (40, 8)]
|
||||
MODELS_FP8 = [
|
||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||
]
|
||||
HEADS = [(32, 8), (40, 8)]
|
||||
else:
|
||||
MODELS = []
|
||||
HEADS = []
|
||||
BACKENDS = [
|
||||
_Backend.ROCM_AITER_UNIFIED_ATTN,
|
||||
_Backend.ROCM_ATTN,
|
||||
_Backend.TRITON_ATTN,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
||||
@ -269,46 +282,36 @@ else:
|
||||
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
||||
@pytest.mark.parametrize(
|
||||
"backend",
|
||||
[_Backend.FLASHINFER]
|
||||
if current_platform.is_cuda()
|
||||
else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
|
||||
)
|
||||
# TODO(boyuan): test inductor graph partition on rocm
|
||||
@pytest.mark.parametrize(
|
||||
"use_inductor_graph_partition",
|
||||
[False] if current_platform.is_rocm() else [False, True],
|
||||
"backend, model_name, model_class, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
|
||||
# quant_fp4 only has the custom impl
|
||||
+ list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||
)
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
    current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
    reason="On CUDA only test on SM100(Blackwell)",
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
def test_attention_quant_pattern(
    num_qo_heads: int,
    num_kv_heads: int,
    head_size: int,
    batch_size: int,
    dtype: torch.dtype,
    custom_ops: str,
    model_name: str,
    model_class: type[AttentionQuantPatternModel],
    backend: _Backend,
    use_inductor_graph_partition: bool,
    dist_init,
    caplog_vllm,
):
    """Test AttentionStaticQuantPattern fusion pass"""
    if backend == _Backend.FLASHINFER and (
        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")

    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
    custom_ops_list = custom_ops.split(",") if custom_ops else []

    device = torch.device("cuda:0")
    torch.manual_seed(42)
@ -322,8 +325,7 @@ def test_attention_quant_pattern(
        scheduler_config=SchedulerConfig(max_num_seqs=1024),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+quant_fp8"],
            use_inductor_graph_partition=use_inductor_graph_partition,
            custom_ops=custom_ops_list,
        ),
        cache_config=CacheConfig(cache_dtype="fp8"),
    )
@ -358,8 +360,9 @@ def test_attention_quant_pattern(
        forward_ctx = get_forward_context()
        forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)

        # Run model directly without compilation and fusion
        result_unfused = model_unfused(q, k, v)
        # Run model directly without fusion
        # Still compile so query QuantFP8 has closer numerics
        result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v)

        # Run model with attn fusion enabled
        vllm_config.compilation_config.pass_config = PassConfig(
@ -414,16 +417,25 @@ def test_attention_quant_pattern(
    )

    # Check attn fusion support
    quant_key = model_class.quant_key
    quant_key: QuantKey = model_class.quant_key
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key)
        for key, layer in vllm_config.compilation_config.static_forward_context.items()
    ]
    if any(attn_fusion_supported):
        # Check quantization ops in the graph before and after fusion
        # Note: fully_replaced=False because query quant ops remain in graph.
        # Only output quant ops are fused into attention.
        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
    assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
        "All layers should support attention fusion"
    )

    # Check quantization ops in the graph before and after fusion
    quant_op = (
        torch.ops.aten.reciprocal
        if "-quant_fp8" in custom_ops_list
        else QUANT_OPS[quant_key]
    )

    # Note: for fp8, fully_replaced=False because query quant ops remain in graph.
    # Only output quant ops are fused into attention.
    test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)

    # access the underlying `AttnFusionPass` on the `LazyInitPass`
    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
tests/compile/test_fusions_e2e.py (new file, 305 lines)
@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import itertools
import logging
from collections.abc import Iterable
from typing import Any, NamedTuple

import pytest
import regex as re

from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer

from ..utils import flat_product, multi_gpu_test


class ModelBackendTestCase(NamedTuple):
    model_name: str
    model_kwargs: dict[str, Any]
    backend: _Backend
    attention_fusions: int
    allreduce_fusions: int | None = None


MODELS_FP8: list[ModelBackendTestCase] = []
MODELS_FP4: list[ModelBackendTestCase] = []
MODELS: list[ModelBackendTestCase] = []  # tp-only

if current_platform.is_cuda():
    MODELS_FP8 = [
        ModelBackendTestCase(
            # Use smaller model for L40s in CI
            model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
            model_kwargs=dict(max_model_len=1024),
            backend=_Backend.TRITON_ATTN,
            attention_fusions=32,
            allreduce_fusions=65,
        ),
        ModelBackendTestCase(
            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=_Backend.FLASHINFER,
            attention_fusions=48,
            allreduce_fusions=96,
        ),
    ]

    MODELS_FP4 = [
        ModelBackendTestCase(
            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=_Backend.FLASHINFER,
            attention_fusions=48,
            allreduce_fusions=96,
        ),
    ]

    # TP only
    MODELS = [
        ModelBackendTestCase(
            model_name="meta-llama/Llama-3.1-8B-Instruct",
            model_kwargs=dict(max_model_len=1024),
            backend=_Backend.TRITON_ATTN,
            attention_fusions=0,
            allreduce_fusions=65,
        ),
    ]

elif current_platform.is_rocm():
    MODELS_FP8 = [
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=_Backend.TRITON_ATTN,
            attention_fusions=32,
        ),
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=_Backend.ROCM_ATTN,
            attention_fusions=32,
        ),
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
            attention_fusions=32,
        ),
    ]

# TODO(luka) test both in nightly
CUSTOM_OPS_FP8 = ["-quant_fp8"]  # , "+quant_fp8"]


@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, "
    "attention_fusions, allreduce_fusions, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
    # quant_fp4 only has the custom impl
    + list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
def test_attn_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: _Backend,
    attention_fusions: int,
    allreduce_fusions: int,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    if backend == _Backend.FLASHINFER and (
        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        custom_ops=custom_ops_list,
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        splitting_ops=splitting_ops,
        # Common
        level=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

    matches = re.findall(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(matches) == 1, log_holder.text
    assert int(matches[0]) == attention_fusions


# TODO(luka) test both in nightly
CUSTOM_OPS_RMS_NORM = ["-rms_norm"]  # , "+rms_norm"]


def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
    for op_list in itertools.product(*custom_ops_lists):
        yield ",".join(op_list)


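# Illustrative sketch (added for clarity, not part of the original diff): roughly
# how the parametrize tuples above and below expand. flat_product flattens each
# ModelBackendTestCase and appends the custom_ops string; custom_ops_product joins
# one choice from each op list into a single comma-separated string. The helper
# below is hypothetical and never called by the tests.
def _sketch_parametrize_expansion() -> None:
    ops = list(custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM))
    # With the defaults above, ops == ["-quant_fp8,-rms_norm"]
    cases = list(flat_product(MODELS_FP8, ops))
    # Each case is a flat tuple (model_name, model_kwargs, backend,
    # attention_fusions, allreduce_fusions, custom_ops) that
    # @pytest.mark.parametrize unpacks directly.
    assert all(len(case) == 6 for case in cases)

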
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, "
    "attention_fusions, allreduce_fusions, custom_ops",
    # Toggle RMSNorm and QuantFP8 for FP8 models
    list(
        flat_product(
            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
        )
    )
    # Toggle RMSNorm for FP4 models and unquant models
    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not has_flashinfer()
    or not current_platform.has_device_capability(90),
    reason="allreduce+rmsnorm fusion requires flashinfer",
)
def test_tp2_attn_quant_allreduce_rmsnorm(
    model_name: str,
    model_kwargs: dict,
    backend: _Backend,
    attention_fusions: int,
    allreduce_fusions: int,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        custom_ops=custom_ops_list,
        splitting_ops=splitting_ops,
        # Common
        level=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(
            enable_attn_fusion=True,
            enable_noop=True,
            enable_fi_allreduce_fusion=True,
        ),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(
            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
        )
    matches = re.findall(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(matches) == 2, log_holder.text

    assert int(matches[0]) == attention_fusions
    assert int(matches[1]) == attention_fusions

    matches = re.findall(
        r"collective_fusion.py:\d+] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(matches) == 2, log_holder.text

    assert int(matches[0]) == allreduce_fusions
    assert int(matches[1]) == allreduce_fusions


def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
    compilation_config = (
        compile_config
        if isinstance(compile_config, CompilationConfig)
        else CompilationConfig(level=compile_config)
    )

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)
    # Allow override from model_kwargs
    model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
    model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}

    # No cudagraphs by default
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    llm = LLM(
        model=model,
        compilation_config=compilation_config,
        **model_kwargs,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
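

# Illustrative sketch (added for clarity, not part of the original diff): run_model
# can be invoked directly to reproduce a fusion-enabled end-to-end run outside of
# pytest. The model name and option values below are assumptions for illustration.
if __name__ == "__main__":
    run_model(
        CompilationConfig(
            level=CompilationMode.VLLM_COMPILE,
            pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
        ),
        "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
        max_model_len=1024,
    )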
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.compilation.pass_manager import PostGradPassManager
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
|
||||
# dummy custom pass that doesn't inherit
|
||||
@ -42,7 +42,8 @@ class ProperPass(InductorPass):
|
||||
],
|
||||
)
|
||||
def test_pass_manager_uuid(callable):
|
||||
config = VllmConfig()
|
||||
# Some passes need dtype to be set
|
||||
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
|
||||
@ -18,6 +18,8 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
@ -42,9 +44,7 @@ prompts = [
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||
):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
@ -95,13 +95,11 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestQuantModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||
):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.vllm_config = vllm_config
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size)), requires_grad=False
|
||||
)
|
||||
@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model(
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True,
|
||||
)
|
||||
) # NoOp needed for fusion
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
device_config=device_config,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
|
||||
with set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
passes_for_backend: list[VllmInductorPass] = [
|
||||
noop_pass,
|
||||
sequence_parallelism_pass,
|
||||
]
|
||||
|
||||
if enable_fusion:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
if enable_fusion:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
|
||||
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
|
||||
model = test_model_cls(hidden_size, hidden_size * 2)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
compiled_model_no_func(hidden_states, residual)
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
compiled_model_no_func(hidden_states, residual)
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
|
||||
assert sequence_parallelism_pass.matched_count == 1
|
||||
assert sequence_parallelism_pass.matched_count == 1
|
||||
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model():
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model())
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in model.ops_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model())
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
import contextlib
|
||||
import pathlib
|
||||
from copy import deepcopy
|
||||
|
||||
from tblib import pickling_support
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
# Install support for pickling exceptions so that we can nicely propagate
|
||||
# failures from tests running in a subprocess.
|
||||
# This should be run before any custom exception subclasses are defined.
|
||||
@ -40,7 +43,7 @@ from transformers import (
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import LLM, SamplingParams, envs
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
@ -1070,6 +1073,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
|
||||
yield caplog
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def caplog_mp_fork():
|
||||
"""
|
||||
This fixture enables capturing logs from a forked MP subprocess.
|
||||
It should be used in conjunction with caplog_vllm.
|
||||
|
||||
By default, subprocess logs do not go through the parent process.
|
||||
We instead create a queue listener in the parent process which
|
||||
forwards logs to the logger's other handlers, and add a QueueHandler
|
||||
to the root logger. Forked subprocesses will inherit the root logger
|
||||
and pass their messages to the queue, which the listener will forward
|
||||
to the root logger, which can be captured by caplog.
|
||||
|
||||
Note that this workaround only works for fork; with spawn, the subprocess
|
||||
reinitializes logging and does not automatically inherit the queue.
|
||||
We'd have to manually pass the queue to the subprocess at the spawn point.
|
||||
See caplog_mp_spawn below.
|
||||
"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ctx():
|
||||
import logging.handlers
|
||||
import multiprocessing as mp
|
||||
|
||||
logger_queue: mp.Queue[logging.LogRecord] = mp.Queue()
|
||||
logger = logging.getLogger()
|
||||
handlers = logger.handlers
|
||||
|
||||
# The listener works on a background thread, not inherited by the child.
|
||||
queue_listener = logging.handlers.QueueListener(logger_queue, *handlers)
|
||||
queue_listener.start()
|
||||
|
||||
# Add queue handler after creating the listener to avoid cycle
|
||||
logger.addHandler(logging.handlers.QueueHandler(logger_queue))
|
||||
yield
|
||||
queue_listener.stop()
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
class LogHolder:
|
||||
def __init__(self):
|
||||
self.text = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def caplog_mp_spawn(tmp_path, monkeypatch):
|
||||
"""
|
||||
This fixture enables capturing logs from a forked MP subprocess.
|
||||
It does not require caplog_vllm (but it only contains logs from the child).
|
||||
|
||||
By default, subprocess logs do not go through the parent process.
|
||||
We instead add a FileHandler to the config so the spawned child process
|
||||
writes its logs to a temp file.
|
||||
In the parent, we read the file and return the contents.
|
||||
|
||||
Note: this method could be extended to fork by either reconfiguring logging
|
||||
in the parent or using a SocketHandler:
|
||||
https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501
|
||||
"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ctx(level: int | str):
|
||||
from vllm.logger import DEFAULT_LOGGING_CONFIG
|
||||
|
||||
config_path = tmp_path / "vllm_logging_config.json"
|
||||
log_path = tmp_path / "vllm.log"
|
||||
log_holder = LogHolder()
|
||||
|
||||
config = deepcopy(DEFAULT_LOGGING_CONFIG)
|
||||
if envs.VLLM_LOGGING_CONFIG_PATH:
|
||||
path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH)
|
||||
assert path.exists()
|
||||
config = json.loads(path.read_text())
|
||||
|
||||
config["loggers"]["vllm"]["handlers"] += ["vllm_file"]
|
||||
config["handlers"]["vllm_file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "vllm",
|
||||
"level": level,
|
||||
"filename": log_path.as_posix(),
|
||||
}
|
||||
|
||||
config_path.write_text(json.dumps(config))
|
||||
|
||||
with monkeypatch.context() as monkeypatch_ctx:
|
||||
monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix())
|
||||
monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1")
|
||||
yield log_holder
|
||||
|
||||
log_holder.text = log_path.read_text()
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def num_gpus_available():
|
||||
"""Get number of GPUs without initializing the CUDA context
|
||||
|
||||
@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
|
||||
.clamp(fp8_traits_min, fp8_traits_max)
|
||||
.to(FP8_DTYPE)
|
||||
)
|
||||
return ref_out, ref_scale.view((1,))
|
||||
return ref_out, ref_scale.view((1, 1))
|
||||
|
||||
|
||||
def native_w8a8_block_matmul(
|
||||
|
||||
@ -501,3 +501,49 @@ def test_streaming_complete_logs_full_text_content():
|
||||
assert call_args[1] == "test-streaming-full-text"
|
||||
assert call_args[2] == " (streaming complete)"
|
||||
assert call_args[5] == "streaming_complete"
|
||||
|
||||
|
||||
# Add vllm prefix to make sure logs go through the vllm logger
|
||||
test_logger = init_logger("vllm.test_logger")
|
||||
|
||||
|
||||
def mp_function(**kwargs):
|
||||
# This function runs in a subprocess
|
||||
|
||||
test_logger.warning("This is a subprocess: %s", kwargs.get("a"))
|
||||
test_logger.error("This is a subprocess error.")
|
||||
test_logger.debug("This is a subprocess debug message: %s.", kwargs.get("b"))
|
||||
|
||||
|
||||
def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork):
|
||||
with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork():
|
||||
import multiprocessing
|
||||
|
||||
ctx = multiprocessing.get_context("fork")
|
||||
p = ctx.Process(
|
||||
target=mp_function,
|
||||
name=f"SubProcess{1}",
|
||||
kwargs={"a": "AAAA", "b": "BBBBB"},
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
assert "AAAA" in caplog_vllm.text
|
||||
assert "BBBBB" in caplog_vllm.text
|
||||
|
||||
|
||||
def test_caplog_mp_spawn(caplog_mp_spawn):
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
import multiprocessing
|
||||
|
||||
ctx = multiprocessing.get_context("spawn")
|
||||
p = ctx.Process(
|
||||
target=mp_function,
|
||||
name=f"SubProcess{1}",
|
||||
kwargs={"a": "AAAA", "b": "BBBBB"},
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
assert "AAAA" in log_holder.text
|
||||
assert "BBBBB" in log_holder.text
|
||||
|
||||
@ -6,6 +6,7 @@ import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import importlib
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@ -15,7 +16,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterable
|
||||
from contextlib import ExitStack, contextmanager, suppress
|
||||
from multiprocessing import Process
|
||||
from pathlib import Path
|
||||
@ -1261,3 +1262,23 @@ def check_answers(
|
||||
frac_ok = numok / len(answer)
|
||||
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
|
||||
assert frac_ok >= accept_rate
|
||||
|
||||
|
||||
def flat_product(*iterables: Iterable[Any]):
|
||||
"""
|
||||
Flatten lists of tuples of the cartesian product.
|
||||
Useful when we want to avoid nested tuples to allow
|
||||
test params to be unpacked directly from the decorator.
|
||||
|
||||
Example:
|
||||
flat_product([(1, 2), (3, 4)], ["a", "b"]) ->
|
||||
[
|
||||
(1, 2, "a"),
|
||||
(1, 2, "b"),
|
||||
(3, 4, "a"),
|
||||
(3, 4, "b"),
|
||||
]
|
||||
"""
|
||||
for element in itertools.product(*iterables):
|
||||
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
|
||||
yield tuple(itertools.chain(*normalized))
|
||||
|
||||
@ -40,7 +40,7 @@ from vllm.utils import (
|
||||
unique_filepath,
|
||||
)
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
from ..utils import create_new_process_for_each_test, flat_product
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||
@ -771,3 +771,25 @@ def test_unique_filepath():
|
||||
paths.add(path)
|
||||
assert len(paths) == 10
|
||||
assert len(list(Path(temp_dir).glob("*.txt"))) == 10
|
||||
|
||||
|
||||
def test_flat_product():
|
||||
# Check regular itertools.product behavior
|
||||
result1 = list(flat_product([1, 2, 3], ["a", "b"]))
|
||||
assert result1 == [
|
||||
(1, "a"),
|
||||
(1, "b"),
|
||||
(2, "a"),
|
||||
(2, "b"),
|
||||
(3, "a"),
|
||||
(3, "b"),
|
||||
]
|
||||
|
||||
# check that the tuples get flattened
|
||||
result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
|
||||
assert result2 == [
|
||||
(1, 2, "a", 5, 6),
|
||||
(1, 2, "b", 5, 6),
|
||||
(3, 4, "a", 5, 6),
|
||||
(3, 4, "b", 5, 6),
|
||||
]
|
||||
|
||||
@ -1507,7 +1507,7 @@ def scaled_fp8_quant(
|
||||
output, input, scale, scale_ub
|
||||
)
|
||||
else:
|
||||
scale = torch.empty(1, device=input.device, dtype=torch.float32)
|
||||
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
assert scale.numel() == 1, f"{scale.shape}"
|
||||
|
||||
@ -17,10 +17,14 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -41,11 +45,8 @@ else:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class BasePattern:
|
||||
@ -669,33 +670,24 @@ class AllReduceRMSNormPattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
return [input, rms_result, weight]
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
|
||||
):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=rms_result,
|
||||
input=allreduce_output,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
# rms_result, allreduce_output
|
||||
return rms[1], allreduce_output
|
||||
rms = self.rmsnorm_matcher(allreduce_output, weight)
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
|
||||
):
|
||||
return rms, allreduce_output
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
residual = torch.zeros_like(input)
|
||||
rms_result = torch.empty_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@ -733,29 +725,19 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
return [
|
||||
residual,
|
||||
input,
|
||||
weight,
|
||||
]
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
# input, residual
|
||||
return rms[1], rms[2]
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
return rms, residual
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
@ -779,6 +761,18 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
# Same pattern, but only return the output and not residual
|
||||
# (helpful for end of graph where residual is not used again)
|
||||
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
||||
|
||||
pm.register_replacement(
|
||||
first_return_only(pattern),
|
||||
first_return_only(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
"""
|
||||
@ -799,60 +793,37 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
rmsnorm_result = torch.empty(
|
||||
[1, 8, 4], device=self.device, dtype=self.dtype
|
||||
)
|
||||
quant_result = torch.empty(
|
||||
[1, 8, 4], device=self.device, dtype=self.quant_dtype
|
||||
)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, rmsnorm_result, quant_result, weight, scale]
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight, scale]
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rmsnorm_out_tuple = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=rmsnorm_result,
|
||||
input=all_reduce,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP8_QUANT_OP,
|
||||
result=quant_result,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output
|
||||
return quant_out_tuple[1], all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=result_rms,
|
||||
quant_out=quant_result,
|
||||
quant_out=result_quant,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
@ -892,64 +863,42 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty(
|
||||
[4, 4], device=self.device, dtype=self.quant_dtype
|
||||
)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [
|
||||
quant_result,
|
||||
residual,
|
||||
input,
|
||||
weight,
|
||||
scale,
|
||||
]
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight, scale]
|
||||
|
||||
def pattern(
|
||||
quant_result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
|
||||
fused_add_rmsnorm_out_tuple = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP8_QUANT_OP,
|
||||
result=quant_result,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output
|
||||
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
|
||||
return quant, res
|
||||
|
||||
def replacement(
|
||||
quant_result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=quant_result,
|
||||
quant_out=result_quant,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
@ -986,14 +935,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
|
||||
|
||||
rmsnorm_result = torch.empty(
|
||||
[1, 16, 16], device=self.device, dtype=self.dtype
|
||||
)
|
||||
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
|
||||
input_global_scale = torch.empty(
|
||||
[1, 1], device=self.device, dtype=torch.float32
|
||||
@ -1001,36 +947,21 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
weight = torch.empty([16], device=self.device, dtype=self.dtype)
|
||||
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
|
||||
|
||||
return [
|
||||
input,
|
||||
rmsnorm_result,
|
||||
quant_result,
|
||||
weight,
|
||||
input_global_scale,
|
||||
output_scale,
|
||||
]
|
||||
return [input, quant_result, weight, input_global_scale, output_scale]
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rmsnorm_out_tuple = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=rmsnorm_result,
|
||||
input=all_reduce,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
)
|
||||
@ -1040,13 +971,13 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
):
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@ -1090,6 +1021,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
@ -1121,28 +1053,17 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
input_global_scale: torch.Tensor,
|
||||
):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
|
||||
fused_add_rmsnorm_out_tuple = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
return (
|
||||
quant_out_tuple[1],
|
||||
fused_add_rmsnorm_out_tuple[2],
|
||||
quant_out_tuple[2],
|
||||
)
|
||||
return quant_out_tuple[1], residual, quant_out_tuple[2]
|
||||
|
||||
def replacement(
|
||||
quant_result: torch.Tensor,
|
||||
|
||||
@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -92,13 +93,19 @@ class RMSNormQuantPattern:
|
||||
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
|
||||
assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
|
||||
self.QUANT_OP = QUANT_OPS[key.quant]
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
|
||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon)
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(key.quant)
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
# Cannot use methods, as the self argument affects tracing
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale
|
||||
)
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
return self.quant_matcher(result_rms, scale)[0]
|
||||
|
||||
# result
|
||||
return at2[1]
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
# input, weight
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
pattern(*inputs)
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at[1], scale=scale
|
||||
)
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, _ = self.quant_matcher(result_rms, scale)
|
||||
|
||||
# result, residual
|
||||
return at1[1], at[2]
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
# input, weight, residual
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
|
||||
)
|
||||
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
# result, scale
|
||||
return at2[1], at2[2]
|
||||
return self.quant_matcher(result_rms)
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
|
||||
)
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
# result, residual, scale
|
||||
return at1[1], at[2], at1[2]
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
pass_name="rmsnorm_quant_fusion_pass"
|
||||
)
|
||||
|
||||
# Make sure fused add patterns are before simple rms norm,
|
||||
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
|
||||
@ -2,9 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
@ -20,7 +22,9 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .fx_utils import is_func
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -66,9 +70,13 @@ class AttentionQuantPattern(ABC):
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(process_fx, trace_fn):
|
||||
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
|
||||
def wrapped(*args, **kwargs):
|
||||
return process_fx(trace_fn(*args, **kwargs))
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
|
||||
return gm
|
||||
|
||||
return wrapped
|
||||
|
||||
@ -77,7 +85,20 @@ class AttentionQuantPattern(ABC):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
return gm
|
||||
|
||||
    @staticmethod
    def remove_noop_permutes(gm: torch.fx.GraphModule):
        for node in gm.graph.nodes:
            if not is_func(node, torch.ops.aten.permute.default):
                continue

            dims = node.args[1]
            if any(dim != i for i, dim in enumerate(dims)):
                continue

            # this is now an identity op, remove
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)

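    # Illustrative sketch (added for clarity, not part of the original diff):
    # remove_noop_permutes only strips permutes whose dims equal range(ndim),
    # e.g. aten.permute(x, [0, 1]) on a 2-D tensor. A rough, hypothetical usage:
    #
    #     gm = torch.fx.symbolic_trace(
    #         lambda x: torch.ops.aten.permute.default(x, [0, 1]) + 1
    #     )
    #     AttentionQuantPattern.remove_noop_permutes(gm)
    #     gm.recompile()  # the identity permute node is gone; only the add remains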
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
||||
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||
@ -108,6 +129,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)

def _register(self, pm_pass: PatternMatcherPass):
def pattern(
@ -115,7 +137,6 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
@ -131,17 +152,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
at2 = auto_functionalized(
self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale
)
return at2[1]

return self.quant_matcher(attn_out_view, scale)[0]

def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor,
):
# attn output in quant_dtype
@ -164,13 +182,10 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

inputs = [
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v
self.empty(
5, self.num_heads, self.head_size, dtype=self.dtype
), # attn_output
self.empty_quant(5, self.num_heads * self.head_size), # quant_output
self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size), # k
self.empty(5, self.num_heads, self.head_size), # v
self.empty(5, self.num_heads, self.head_size), # attn_output
empty_fp32(1, 1), # scale
]

@ -179,7 +194,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
@ -279,7 +296,9 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
@ -6,7 +6,7 @@ from collections.abc import Iterable, Iterator

from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload
from torch._ops import OpOverload, OpOverloadPacket


def is_func(node: fx.Node, target) -> bool:
@ -64,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:


# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
# Also handles op overload packets and finds all overloads
def find_op_nodes(
op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
if isinstance(op, OpOverloadPacket):
for overload in op.overloads():
overload_op = getattr(op, overload)
yield from find_op_nodes(overload_op, graph)
return

assert isinstance(op, OpOverload)
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)
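find_op_nodes now also accepts an OpOverloadPacket and recurses into each of its overloads. A small standalone sketch of that overload iteration (illustration only):

import torch
from torch._ops import OpOverload, OpOverloadPacket

packet = torch.ops.aten.permute  # an OpOverloadPacket
assert isinstance(packet, OpOverloadPacket)
for name in packet.overloads():          # e.g. ["default"]
    overload = getattr(packet, name)     # an OpOverload such as aten.permute.default
    assert isinstance(overload, OpOverload)
    print(overload, overload._schema.is_mutable)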
208 vllm/compilation/matcher_utils.py Normal file
@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod

import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload

from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    _normalize_quant_group_shape,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
from vllm.platforms import current_platform

RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}

if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501


class MatcherCustomOp(ABC):
    def __init__(self, enabled: bool):
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None
        self.device = config.device_config.device if config.device_config else None

        self.enabled = enabled
        self.forward = self.forward_custom if enabled else self.forward_native

    @abstractmethod
    def forward_custom(self, *args, **kws):
        pass

    @abstractmethod
    def forward_native(self, *args, **kws):
        pass

    def __call__(self, *args, **kws):
        return self.forward(*args, **kws)

    def empty(self, *args, **kws):
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)

    def empty_f32(self, *args, **kws):
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)

    def inputs(self) -> list[torch.Tensor]:
        """Utility for inputs to the pattern"""
        raise NotImplementedError
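MatcherCustomOp binds __call__ to either the custom-op or the native formulation at construction time, so a single pattern class can match graphs traced with either implementation. A toy analogue of that dispatch (illustration only, not vLLM code):

import torch

class ToyMatcherRelu:
    """Mimics MatcherCustomOp's enabled/native dispatch for a trivial op."""

    def __init__(self, enabled: bool):
        self.forward = self.forward_custom if enabled else self.forward_native

    def forward_custom(self, x: torch.Tensor) -> torch.Tensor:
        # stands in for the custom-op formulation of the operation
        return torch.ops.aten.relu.default(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # equivalent "native" torch formulation
        return torch.maximum(x, torch.zeros_like(x))

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)

x = torch.randn(4)
assert torch.allclose(ToyMatcherRelu(True)(x), ToyMatcherRelu(False)(x))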
class MatcherRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon

    def inputs(self):
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        return [input, weight]

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        result = torch.empty_like(input)
        _, result = auto_functionalized(
            RMS_OP,
            result=result,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight
        )


class MatcherFusedAddRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon

    def inputs(self):
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        residual = self.empty(5, 16)
        return [input, weight, residual]

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        _, result, residual = auto_functionalized(
            RMS_ADD_OP,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result, residual

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
        )


class MatcherQuantFP8(MatcherCustomOp):
    def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
        if enabled is None:
            enabled = QuantFP8.enabled()

        super().__init__(enabled)
        self.quant_key = quant_key
        assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
        self.QUANT_OP = QUANT_OPS[quant_key]

        assert quant_key.dtype == current_platform.fp8_dtype(), (
            "Only QuantFP8 supported by"
        )
        assert quant_key.scale2 is None
        self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)

    def forward_custom(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )

        if self.quant_key.scale.static:
            assert scale is not None
            _, result = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale
            )
            return result, scale
        else:
            assert scale is None
            scale = self.make_scale(input)
            _, result, scale = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
            )
            return result, scale

    def forward_native(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.quant_fp8(input, scale)

    def make_scale(self, input: torch.Tensor):
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
        scale_shape = (
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )

        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

    def inputs(self) -> list[torch.Tensor]:
        input = self.empty(5, 16)
        if self.quant_key.scale.static:
            return [input, self.empty_f32(1, 1)]

        return [input]
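make_scale above sizes the dynamic scale tensor from the quantization group shape: one scale per group of input elements. A rough sketch of that arithmetic (illustration only; the names and group-shape convention here are assumptions, not the diff's code):

def scale_shape(input_shape: tuple[int, int], group_shape: tuple[int, int]) -> tuple[int, int]:
    # one scale element per (g_rows x g_cols) block of the input
    rows, cols = input_shape
    g_rows, g_cols = group_shape
    return (rows // g_rows, cols // g_cols)

# e.g. a (5, 16) activation:
assert scale_shape((5, 16), (1, 16)) == (5, 1)   # per-token: one scale per row
assert scale_shape((5, 16), (5, 16)) == (1, 1)   # per-tensor: a single scale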
@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
import depyf

path.mkdir(parents=True, exist_ok=True)
logger.debug("Dumping depyf output to %s", path)
global context_manager
context_manager = depyf.prepare_debug(path.as_posix())
context_manager.__enter__()

@ -5,7 +5,7 @@ import functools
from torch import fx as fx

from vllm import envs
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import set_env_var
@ -88,27 +88,30 @@ class PostGradPassManager(CustomGraphPass):

def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]

if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
# Set the current vllm config to allow tracing CustomOp instances
with set_current_vllm_config(config, check_compile=False):
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]

if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]

if self.pass_config.enable_fusion:
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]

if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fusion:
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]

# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]

# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)

# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph
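The configure change above wraps pass construction in set_current_vllm_config so that CustomOp-based pattern helpers (such as the matchers in matcher_utils.py) can read the active config when they are built. A toy version of that "current config" pattern (illustration only, not vLLM code):

import contextlib

_current_config = None

@contextlib.contextmanager
def set_current_config(cfg):
    """Install cfg as the ambient config for the duration of the block."""
    global _current_config
    prev, _current_config = _current_config, cfg
    try:
        yield
    finally:
        _current_config = prev

class TracedThing:
    def __init__(self):
        # reads whatever config is active at construction time
        self.cfg = _current_config

with set_current_config({"dtype": "bf16"}):
    thing = TracedThing()
assert thing.cfg == {"dtype": "bf16"}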
@ -128,7 +128,8 @@ class VllmPatternMatcherPass(VllmInductorPass):
f" please add to dump_patterns if there are any errors.\n\n"
f"from torch._higher_order_ops.auto_functionalize import "
f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *",
f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm",
file=f,
)


@ -178,14 +178,11 @@ class RMSNorm(CustomOp):
self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size
)
weight_dtype = dtype or torch.get_default_dtype()
self.has_weight = has_weight
if dtype is not None:
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size)
self.weight = torch.ones(hidden_size, dtype=weight_dtype)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
weight_dtype = self.weight.data.dtype

if current_platform.is_rocm():
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
@ -195,46 +192,68 @@ class RMSNorm(CustomOp):
with_fused_add=True, dtype=weight_dtype
)

@staticmethod
def forward_static(
x: torch.Tensor,
variance_epsilon: float,
hidden_size: int,
orig_dtype: torch.dtype,
weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
x = x.to(torch.float32)
if residual is not None:
# residual promoted f16->f32 automatically,
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x = x + residual
residual = x.to(orig_dtype)

if x.shape[-1] != hidden_size:
raise ValueError(
f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
)

if variance_size_override is None:
x_var = x
else:
if hidden_size < variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{variance_size_override}, but found: {hidden_size}"
)

x_var = x[:, :, :variance_size_override]

variance = x_var.pow(2).mean(dim=-1, keepdim=True)

x = x * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype)
if weight is not None:
x = x * weight
if residual is None:
return x
else:
return x, residual
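forward_static above is the RMSNorm math the pattern matcher now traces: x / sqrt(mean(x^2) + eps), computed in float32, cast back to the original dtype, then scaled by the weight. A quick standalone numerical reference (illustration only; forward_static itself additionally handles residuals and variance-size overrides):

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    xf = x.to(torch.float32)
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    # normalize in float32, cast back, then apply the weight
    return (xf * torch.rsqrt(variance + eps)).to(orig_dtype) * weight

x = torch.randn(5, 16, dtype=torch.bfloat16)
w = torch.ones(16, dtype=torch.bfloat16)
out = rms_norm_ref(x, w)
assert out.shape == (5, 16) and out.dtype == torch.bfloat16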
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)

hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError(
"Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}"
)

if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}"
)

x_var = x[:, :, : self.variance_size_override]

variance = x_var.pow(2).mean(dim=-1, keepdim=True)

x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else:
return x, residual
return self.forward_static(
x,
self.variance_epsilon,
self.hidden_size,
x.dtype,
self.weight.data if self.has_weight else None,
residual,
self.variance_size_override,
)

def forward_cuda(
self,