[6/N] torch.compile rollout to users (#10437)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@ -1,5 +0,0 @@
|
||||
{
|
||||
"use_cudagraph": true,
|
||||
"non_cudagraph_ops": ["silly.attention"],
|
||||
"cudagraph_copy_inputs": true
|
||||
}
|
||||
@ -2,7 +2,6 @@
|
||||
Test the piecewise compilation with a simple model so that we
|
||||
can exactly calculate the expected output and side effects.
|
||||
"""
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -11,7 +10,7 @@ from torch.library import Library
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.plugins import set_current_vllm_config
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@ -77,12 +76,12 @@ class SillyModel(nn.Module):
|
||||
|
||||
def test_simple_piecewise_compile():
|
||||
|
||||
directory = os.path.dirname(__file__)
|
||||
config = os.path.join(directory, "piecewise_compilation_config.json")
|
||||
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
cudagraph_copy_inputs=True,
|
||||
))
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SillyModel(vllm_config=vllm_config, prefix='')
|
||||
|
||||
@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
||||
# clean up to avoid side effects for other tests
|
||||
del os.environ["VLLM_TORCH_COMPILE_CONFIG"]
|
||||
|
||||
@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
|
||||
if the config `tractable_init` is set to True. Otherwise, the weights are
|
||||
initialized randomly with a fixed seed.
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.plugins import set_compilation_config, set_current_vllm_config
|
||||
from vllm.plugins import set_current_vllm_config
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
@ -254,23 +253,17 @@ def run_model(llama_config,
|
||||
split_attn: bool = False) -> torch.Tensor:
|
||||
|
||||
if use_compile:
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
|
||||
CompilationLevel.PIECEWISE)
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
)
|
||||
if split_attn:
|
||||
set_compilation_config(
|
||||
CompilationConfig(
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
))
|
||||
else:
|
||||
set_compilation_config(CompilationConfig(use_cudagraph=True, ))
|
||||
compilation_config.non_cudagraph_ops = ["silly.attention"]
|
||||
else:
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
|
||||
CompilationLevel.NO_COMPILATION)
|
||||
set_compilation_config(None)
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION, )
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config = VllmConfig(compilation_config=compilation_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = LlamaModel(config=llama_config,
|
||||
vllm_config=vllm_config,
|
||||
@ -288,10 +281,6 @@ def run_model(llama_config,
|
||||
input_ids[:2].zero_()
|
||||
output = model(input_ids[:2], positions[:2])
|
||||
|
||||
# manual cleanup
|
||||
del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
|
||||
set_compilation_config(None)
|
||||
|
||||
output = output.cpu()
|
||||
|
||||
if llama_config.tractable_init:
|
||||
@ -361,7 +350,6 @@ def test_toy_llama():
|
||||
|
||||
@torch.inference_mode
|
||||
def benchmark():
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
from triton.testing import do_bench
|
||||
|
||||
# similar to llama 3.1-8B
|
||||
@ -387,15 +375,16 @@ def benchmark():
|
||||
|
||||
for piecewise in [False, True]:
|
||||
if piecewise:
|
||||
set_compilation_config(
|
||||
CompilationConfig(
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
))
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
)
|
||||
else:
|
||||
set_compilation_config(None)
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE, )
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config = VllmConfig(compilation_config=compilation_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = LlamaModel(config=llama_config,
|
||||
vllm_config=vllm_config,
|
||||
|
||||
@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
|
||||
final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
|
||||
["-tp", str(tp_size)]
|
||||
|
||||
all_args: List[List[str]] = []
|
||||
all_envs: List[Optional[Dict[str, str]]] = []
|
||||
|
||||
for level in [
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
CompilationLevel.PIECEWISE,
|
||||
]:
|
||||
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
|
||||
all_args.append(final_args + ["-O", str(level)])
|
||||
all_envs.append({})
|
||||
|
||||
# inductor will change the output, so we only compare if the output
|
||||
# is close, not exactly the same.
|
||||
compare_all_settings(
|
||||
model, [final_args] * 2,
|
||||
model,
|
||||
all_args,
|
||||
all_envs,
|
||||
method=method if method != "generate" else "generate_close")
|
||||
all_envs.clear()
|
||||
all_args.clear()
|
||||
|
||||
for level in [
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
CompilationLevel.DYNAMO_AS_IS,
|
||||
CompilationLevel.DYNAMO_ONCE,
|
||||
]:
|
||||
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
|
||||
all_args.append(final_args + ["-O", str(level)])
|
||||
all_envs.append({})
|
||||
if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
|
||||
# "DYNAMO_ONCE" will always use fullgraph
|
||||
all_envs[-1][
|
||||
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore
|
||||
|
||||
compare_all_settings(model, [final_args] * 3, all_envs, method=method)
|
||||
compare_all_settings(model, all_args * 3, all_envs, method=method)
|
||||
|
||||
@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationLevel
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
TEST_MODELS = [
|
||||
@ -65,7 +65,6 @@ def check_full_graph_support(model,
|
||||
optimization_level,
|
||||
tp_size=1):
|
||||
# make sure these models can be captured in full graph mode
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
|
||||
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
|
||||
|
||||
# The base meta llama uses too much memory.
|
||||
@ -86,6 +85,7 @@ def check_full_graph_support(model,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=tp_size,
|
||||
disable_custom_all_reduce=True,
|
||||
compilation_config=CompilationConfig(level=optimization_level),
|
||||
**model_kwargs)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
Reference in New Issue
Block a user