[6/N] torch.compile rollout to users (#10437)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-19 10:09:03 -08:00
Committed by: GitHub
Parent commit: fd9f124971
This commit: 803f37eaaa

15 changed files with 129 additions and 141 deletions
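
This commit is the user-facing step of the torch.compile rollout: the VLLM_TORCH_COMPILE_LEVEL and VLLM_TORCH_COMPILE_CONFIG environment variables are replaced by an explicit CompilationConfig, threaded through VllmConfig, the compilation_config argument of LLM, and the -O command-line flag, as the diffs below show. A minimal sketch of the resulting user-facing API, assuming a GPU build of vLLM at this commit (the model name is a placeholder):

    from vllm import LLM, SamplingParams
    from vllm.config import CompilationConfig, CompilationLevel

    # Opt in to piecewise compilation with an explicit config object
    # instead of exporting VLLM_TORCH_COMPILE_LEVEL.
    llm = LLM(
        model="meta-llama/Llama-3.2-1B",  # placeholder model
        compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE),
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0))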

View File

@@ -1,5 +0,0 @@
-{
-    "use_cudagraph": true,
-    "non_cudagraph_ops": ["silly.attention"],
-    "cudagraph_copy_inputs": true
-}
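
This on-disk JSON config, previously pointed to by the VLLM_TORCH_COMPILE_CONFIG environment variable, is deleted outright; its three fields reappear verbatim as CompilationConfig keyword arguments in the test diff below.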

View File

@@ -2,7 +2,6 @@
 Test the piecewise compilation with a simple model so that we
 can exactly calculate the expected output and side effects.
 """
-import os
 import torch
 from torch import nn
@@ -11,7 +10,7 @@ from torch.library import Library
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.plugins import set_current_vllm_config
 from vllm.utils import direct_register_custom_op
@@ -77,12 +76,12 @@ class SillyModel(nn.Module):
 def test_simple_piecewise_compile():
-    directory = os.path.dirname(__file__)
-    config = os.path.join(directory, "piecewise_compilation_config.json")
-    os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
-    vllm_config = VllmConfig()
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_cudagraph=True,
+        non_cudagraph_ops=["silly.attention"],
+        cudagraph_copy_inputs=True,
+    ))
     with set_current_vllm_config(vllm_config):
         model = SillyModel(vllm_config=vllm_config, prefix='')
@@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
     output = model(input)
     assert global_counter == 2
     assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
-
-    # clean up to avoid side effects for other tests
-    del os.environ["VLLM_TORCH_COMPILE_CONFIG"]
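
Note that the cleanup block disappears rather than moving: the settings now live on a VllmConfig instance instead of in process-global environment variables, so differently configured models can coexist in one test process. A minimal sketch of that pattern, using only the APIs visible in the diff above:

    from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
    from vllm.plugins import set_current_vllm_config

    piecewise = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE, use_cudagraph=True))
    eager = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.NO_COMPILATION))

    for cfg in (piecewise, eager):
        with set_current_vllm_config(cfg):
            # Build and run a model under this config; there is no
            # os.environ bookkeeping to undo afterwards.
            ...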

View File

@@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
 if the config `tractable_init` is set to True. Otherwise, the weights are
 initialized randomly with a fixed seed.
 """
-import os
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
-from vllm.plugins import set_compilation_config, set_current_vllm_config
+from vllm.plugins import set_current_vllm_config
 from vllm.utils import direct_register_custom_op

 # create a library to hold the custom op
@@ -254,23 +253,17 @@ def run_model(llama_config,
               split_attn: bool = False) -> torch.Tensor:
     if use_compile:
-        os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
-            CompilationLevel.PIECEWISE)
+        compilation_config = CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+            use_cudagraph=True,
+        )
         if split_attn:
-            set_compilation_config(
-                CompilationConfig(
-                    use_cudagraph=True,
-                    non_cudagraph_ops=["silly.attention"],
-                ))
-        else:
-            set_compilation_config(CompilationConfig(use_cudagraph=True, ))
+            compilation_config.non_cudagraph_ops = ["silly.attention"]
     else:
-        os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
-            CompilationLevel.NO_COMPILATION)
-        set_compilation_config(None)
+        compilation_config = CompilationConfig(
+            level=CompilationLevel.NO_COMPILATION, )

-    vllm_config = VllmConfig()
+    vllm_config = VllmConfig(compilation_config=compilation_config)
     with set_current_vllm_config(vllm_config):
         model = LlamaModel(config=llama_config,
                            vllm_config=vllm_config,
@@ -288,10 +281,6 @@ def run_model(llama_config,
         input_ids[:2].zero_()
         output = model(input_ids[:2], positions[:2])

-    # manual cleanup
-    del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
-    set_compilation_config(None)
-
     output = output.cpu()

     if llama_config.tractable_init:
@@ -361,7 +350,6 @@ def test_toy_llama():
 @torch.inference_mode
 def benchmark():
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
     from triton.testing import do_bench

     # similar to llama 3.1-8B
@@ -387,15 +375,16 @@ def benchmark():
     for piecewise in [False, True]:
         if piecewise:
-            set_compilation_config(
-                CompilationConfig(
-                    use_cudagraph=True,
-                    non_cudagraph_ops=["silly.attention"],
-                ))
+            compilation_config = CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
+                use_cudagraph=True,
+                non_cudagraph_ops=["silly.attention"],
+            )
         else:
-            set_compilation_config(None)
+            compilation_config = CompilationConfig(
+                level=CompilationLevel.PIECEWISE, )

-        vllm_config = VllmConfig()
+        vllm_config = VllmConfig(compilation_config=compilation_config)
         with set_current_vllm_config(vllm_config):
             model = LlamaModel(config=llama_config,
                                vllm_config=vllm_config,
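
The benchmark keeps the same structure, building one CompilationConfig per branch and timing the forward pass with Triton's do_bench (imported earlier in this file). A sketch of that timing idiom, where model, input_ids, and positions stand in for the objects the benchmark constructs:

    from triton.testing import do_bench

    # do_bench replays the callable many times on the GPU and returns a
    # runtime estimate in milliseconds.
    latency_ms = do_bench(lambda: model(input_ids, positions))
    print(f"forward pass: {latency_ms:.3f} ms")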

View File

@@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
     final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
         ["-tp", str(tp_size)]

+    all_args: List[List[str]] = []
     all_envs: List[Optional[Dict[str, str]]] = []

     for level in [
             CompilationLevel.NO_COMPILATION,
             CompilationLevel.PIECEWISE,
     ]:
-        all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
+        all_args.append(final_args + ["-O", str(level)])
+        all_envs.append({})

     # inductor will change the output, so we only compare if the output
     # is close, not exactly the same.
     compare_all_settings(
-        model, [final_args] * 2,
+        model,
+        all_args,
         all_envs,
         method=method if method != "generate" else "generate_close")

     all_envs.clear()
+    all_args.clear()

     for level in [
             CompilationLevel.NO_COMPILATION,
             CompilationLevel.DYNAMO_AS_IS,
             CompilationLevel.DYNAMO_ONCE,
     ]:
-        all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
+        all_args.append(final_args + ["-O", str(level)])
+        all_envs.append({})
         if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
             # "DYNAMO_ONCE" will always use fullgraph
             all_envs[-1][
                 "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"  # type: ignore

-    compare_all_settings(model, [final_args] * 3, all_envs, method=method)
+    compare_all_settings(model, all_args * 3, all_envs, method=method)
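
The correctness matrix now selects the compilation level with the -O server argument instead of the VLLM_TORCH_COMPILE_LEVEL environment variable, which is why every entry appended to all_envs is an empty dict. Assuming the vllm serve entrypoint at this commit, and given that PIECEWISE is 3 in the CompilationLevel enum, the shift looks like this on a command line:

    # before this commit: level chosen via the environment
    VLLM_TORCH_COMPILE_LEVEL=3 vllm serve <model>
    # after this commit: level chosen via an argument
    vllm serve <model> -O 3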

View File

@@ -4,7 +4,7 @@ import torch
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel
 from vllm.platforms import current_platform

 TEST_MODELS = [
@@ -65,7 +65,6 @@ def check_full_graph_support(model,
                              optimization_level,
                              tp_size=1):
     # make sure these models can be captured in full graph mode
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
     os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

     # The base meta llama uses too much memory.
@@ -86,6 +85,7 @@ def check_full_graph_support(model,
               enforce_eager=True,
               tensor_parallel_size=tp_size,
               disable_custom_all_reduce=True,
+              compilation_config=CompilationConfig(level=optimization_level),
               **model_kwargs)

     outputs = llm.generate(prompts, sampling_params)
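
The offline API shows the same split: the compilation level rides along as a constructor argument, while the full-graph capture toggle remains an environment variable at this commit. A condensed sketch of the call the test makes (model name and prompt are placeholders):

    import os

    from vllm import LLM, SamplingParams
    from vllm.config import CompilationConfig, CompilationLevel

    # Only the level has moved into CompilationConfig; full-graph capture
    # is still toggled through the environment here.
    os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

    llm = LLM(
        model="meta-llama/Llama-3.2-1B",  # placeholder
        enforce_eager=True,
        compilation_config=CompilationConfig(
            level=CompilationLevel.DYNAMO_ONCE),
    )
    print(llm.generate(["Hello"], SamplingParams(temperature=0.0)))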