[Feature] support sequence parallelism using compilation pass (#16155)

Signed-off-by: cascade812 <cascade812@outlook.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
cascade
2025-04-27 06:29:35 -07:00
committed by GitHub
parent ed7a29d9f8
commit 690fe019f0
21 changed files with 1072 additions and 44 deletions

View File

@ -10,7 +10,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, VllmConfig
from .backend import TestBackend
@ -49,13 +49,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
do_fusion: bool):
torch.set_default_device("cuda")
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config= \
CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
func_pass = FixFunctionalizationPass(config)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)

View File

@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \
CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
backend = TestBackend(noop_pass, fusion_pass)
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

View File

@ -6,7 +6,7 @@ import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig
from vllm.config import VllmConfig
# dummy custom pass that doesn't inherit
@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph):
# Should fail to add directly to the pass manager
def test_bad_callable():
config = CompilationConfig().pass_config
config = VllmConfig()
pass_manager = PostGradPassManager()
pass_manager.configure(config)
@ -43,7 +43,7 @@ class ProperPass(InductorPass):
],
)
def test_pass_manager_uuid(callable):
config = CompilationConfig().pass_config
config = VllmConfig()
pass_manager = PostGradPassManager()
pass_manager.configure(config)
@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable):
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.enable_fusion = not config2.enable_fusion
config2.compilation_config.pass_config.enable_fusion = not \
config2.compilation_config.pass_config.enable_fusion
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)

View File

@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn,
find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
from ..utils import multi_gpu_test
from .backend import TestBackend
OPS_IN_MODEL_BEFORE = [
torch.ops.vllm.all_reduce.default,
]
OPS_IN_MODEL_AFTER = [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default,
]
OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
class TestModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)))
self.norm = RMSNorm(hidden_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
#matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
return norm_output, residual_output
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
num_processes = 2
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(num_processes, batch_size, seq_len,
hidden_size, dtype),
nprocs=nprocs)
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
batch_size: int, seq_len: int,
hidden_size: int,
dtype: torch.dtype):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(
enable_sequence_parallelism=True, ), )
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype=dtype,
seed=42)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
backend_no_func = TestBackend(sequence_parallelism_pass)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(sequence_parallelism_pass, func_pass)
model = TestModel(hidden_size, hidden_size * 2)
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
# Check substitution worked
pre_nodes = backend_no_func.graph_pre_pass.nodes
post_nodes = backend_no_func.graph_post_pass.nodes
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
for op in OPS_IN_MODEL_BEFORE:
find_specified_fn(pre_nodes, op)
for op in OPS_IN_MODEL_AFTER:
assert find_specified_fn_maybe(pre_nodes, op) is None
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
for op in OPS_IN_MODEL_AFTER:
find_specified_fn(post_nodes, op)
for op in OPS_IN_MODEL_BEFORE:
assert find_specified_fn_maybe(post_nodes, op) is None
# check if the functionalization pass is applied
for op in OPS_IN_MODEL:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
op) is None # noqa: E501
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in OPS_IN_MODEL:
if is_func(node, op):
found[op] = True
assert all(found[op] for op in OPS_IN_MODEL)

View File

@ -14,7 +14,8 @@ import torch
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from ..utils import init_test_distributed_environment, multi_process_parallel
@ -47,6 +48,34 @@ def all_reduce_test_worker(
torch.testing.assert_close(t, expected)
@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tp_size)
]
index = rank % tp_size
partition_size = num_elements // tp_size
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
t = all_tensors[index]
t = tensor_model_parallel_reduce_scatter(t, 0)
torch.testing.assert_close(t, expected)
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(
monkeypatch: pytest.MonkeyPatch,

View File

@ -0,0 +1,296 @@
# SPDX-License-Identifier: Apache-2.0
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp" to avoid Ray scheduling
all workers in a node other than the head node, which can cause the test
to fail.
"""
import json
import os
from dataclasses import dataclass
from typing import Literal, NamedTuple, Optional
import pytest
from vllm.config import TaskOption
from vllm.logger import init_logger
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test
logger = init_logger("test_sequence_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
tp_size: int
sp_enabled: bool
eager_mode: bool
chunked_prefill: bool
class SPTestOptions(NamedTuple):
multi_node_only: bool
load_format: Optional[str] = None
@dataclass
class SPTestSettings:
parallel_setups: list[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: list[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: list[str]
task: TaskOption
test_options: SPTestOptions
def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")
@staticmethod
def detailed(
*,
tp_base: int = 2,
multi_node_only: bool = False,
task: TaskOption = "auto",
load_format: Optional[str] = None,
):
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True)
],
distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"],
task=task,
test_options=SPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)
@staticmethod
def fast(
*,
tp_base: int = 2,
task: TaskOption = "auto",
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
],
distributed_backends=["mp", "ray"],
vllm_major_versions=["1", "1"],
task=task,
test_options=SPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)
def iter_params(self, model_id: str):
opts = self.test_options
for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_id, parallel_setup, backend, vllm_major_version,
self.task, opts)
def _compare_sp(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: SPTestOptions,
num_gpus_available: int,
*,
method: Literal["generate", "encode"],
is_multimodal: bool,
):
(
tp_size,
sp_enabled,
eager_mode,
chunked_prefill,
) = parallel_setup
multi_node_only, load_format = test_options
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
if load_format == "dummy":
# Avoid OOM
text_overrides = {
"num_hidden_layers": 4,
"hidden_size": 512,
"intermediate_size": 800,
"num_attention_heads": 4,
"num_key_value_heads": 1,
}
if is_multimodal:
hf_overrides.update({"text_config": text_overrides})
else:
hf_overrides.update(text_overrides)
else:
model_info.check_available_online(on_fail="skip")
pp_size = 1
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")
common_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
if task != "auto":
common_args.extend(["--task", task])
if trust_remote_code:
common_args.append("--trust-remote-code")
if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
compilation_config = {
'level': 3,
'custom_ops': ["+rms_norm"],
'compile_sizes': [4, 8],
'splitting_ops': [],
'pass_config': {
'enable_sequence_parallism': sp_enabled,
'enable_noop': True,
'enable_fusion': True,
},
}
tp_sp_env = tp_env = {
"VLLM_USE_V1": vllm_major_version,
}
tp_sp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",
str(compilation_config),
]
tp_env = {
"VLLM_USE_V1": vllm_major_version,
}
tp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
"mp",
]
try:
compare_two_settings(model_id,
tp_sp_args,
tp_args,
tp_sp_env,
tp_env,
method=method)
except Exception:
testing_ray_compiled_graph = tp_sp_env is not None
if testing_ray_compiled_graph and vllm_major_version == "0":
# Ray Compiled Graph tests are flaky for V0,
# so we don't want to fail the test
logger.exception("Ray Compiled Graph tests failed")
else:
raise
SP_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
}
SP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"meta-llama/Llama-3.2-1B-Instruct",
]
@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"task", "test_options"),
[
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in SP_TEST_MODELS
],
)
@create_new_process_for_each_test()
def test_tp_sp_generation(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: SPTestOptions,
num_gpus_available,
):
_compare_sp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate",
is_multimodal=False)