[6/N] torch.compile rollout to users (#10437)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@ -1,5 +0,0 @@
|
||||
{
|
||||
"use_cudagraph": true,
|
||||
"non_cudagraph_ops": ["silly.attention"],
|
||||
"cudagraph_copy_inputs": true
|
||||
}
|
||||
@ -2,7 +2,6 @@
|
||||
Test the piecewise compilation with a simple model so that we
|
||||
can exactly calculate the expected output and side effects.
|
||||
"""
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -11,7 +10,7 @@ from torch.library import Library
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.plugins import set_current_vllm_config
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@ -77,12 +76,12 @@ class SillyModel(nn.Module):
|
||||
|
||||
def test_simple_piecewise_compile():
|
||||
|
||||
directory = os.path.dirname(__file__)
|
||||
config = os.path.join(directory, "piecewise_compilation_config.json")
|
||||
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
cudagraph_copy_inputs=True,
|
||||
))
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SillyModel(vllm_config=vllm_config, prefix='')
|
||||
|
||||
@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
||||
# clean up to avoid side effects for other tests
|
||||
del os.environ["VLLM_TORCH_COMPILE_CONFIG"]
|
||||
|
||||
@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
|
||||
if the config `tractable_init` is set to True. Otherwise, the weights are
|
||||
initialized randomly with a fixed seed.
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.plugins import set_compilation_config, set_current_vllm_config
|
||||
from vllm.plugins import set_current_vllm_config
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
@ -254,23 +253,17 @@ def run_model(llama_config,
|
||||
split_attn: bool = False) -> torch.Tensor:
|
||||
|
||||
if use_compile:
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
|
||||
CompilationLevel.PIECEWISE)
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
)
|
||||
if split_attn:
|
||||
set_compilation_config(
|
||||
CompilationConfig(
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
))
|
||||
else:
|
||||
set_compilation_config(CompilationConfig(use_cudagraph=True, ))
|
||||
compilation_config.non_cudagraph_ops = ["silly.attention"]
|
||||
else:
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
|
||||
CompilationLevel.NO_COMPILATION)
|
||||
set_compilation_config(None)
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION, )
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config = VllmConfig(compilation_config=compilation_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = LlamaModel(config=llama_config,
|
||||
vllm_config=vllm_config,
|
||||
@ -288,10 +281,6 @@ def run_model(llama_config,
|
||||
input_ids[:2].zero_()
|
||||
output = model(input_ids[:2], positions[:2])
|
||||
|
||||
# manual cleanup
|
||||
del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
|
||||
set_compilation_config(None)
|
||||
|
||||
output = output.cpu()
|
||||
|
||||
if llama_config.tractable_init:
|
||||
@ -361,7 +350,6 @@ def test_toy_llama():
|
||||
|
||||
@torch.inference_mode
|
||||
def benchmark():
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
from triton.testing import do_bench
|
||||
|
||||
# similar to llama 3.1-8B
|
||||
@ -387,15 +375,16 @@ def benchmark():
|
||||
|
||||
for piecewise in [False, True]:
|
||||
if piecewise:
|
||||
set_compilation_config(
|
||||
CompilationConfig(
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
))
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
)
|
||||
else:
|
||||
set_compilation_config(None)
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE, )
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config = VllmConfig(compilation_config=compilation_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = LlamaModel(config=llama_config,
|
||||
vllm_config=vllm_config,
|
||||
|
||||
@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
|
||||
final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
|
||||
["-tp", str(tp_size)]
|
||||
|
||||
all_args: List[List[str]] = []
|
||||
all_envs: List[Optional[Dict[str, str]]] = []
|
||||
|
||||
for level in [
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
CompilationLevel.PIECEWISE,
|
||||
]:
|
||||
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
|
||||
all_args.append(final_args + ["-O", str(level)])
|
||||
all_envs.append({})
|
||||
|
||||
# inductor will change the output, so we only compare if the output
|
||||
# is close, not exactly the same.
|
||||
compare_all_settings(
|
||||
model, [final_args] * 2,
|
||||
model,
|
||||
all_args,
|
||||
all_envs,
|
||||
method=method if method != "generate" else "generate_close")
|
||||
all_envs.clear()
|
||||
all_args.clear()
|
||||
|
||||
for level in [
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
CompilationLevel.DYNAMO_AS_IS,
|
||||
CompilationLevel.DYNAMO_ONCE,
|
||||
]:
|
||||
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
|
||||
all_args.append(final_args + ["-O", str(level)])
|
||||
all_envs.append({})
|
||||
if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
|
||||
# "DYNAMO_ONCE" will always use fullgraph
|
||||
all_envs[-1][
|
||||
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore
|
||||
|
||||
compare_all_settings(model, [final_args] * 3, all_envs, method=method)
|
||||
compare_all_settings(model, all_args * 3, all_envs, method=method)
|
||||
|
||||
@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationLevel
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
TEST_MODELS = [
|
||||
@ -65,7 +65,6 @@ def check_full_graph_support(model,
|
||||
optimization_level,
|
||||
tp_size=1):
|
||||
# make sure these models can be captured in full graph mode
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
|
||||
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
|
||||
|
||||
# The base meta llama uses too much memory.
|
||||
@ -86,6 +85,7 @@ def check_full_graph_support(model,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=tp_size,
|
||||
disable_custom_all_reduce=True,
|
||||
compilation_config=CompilationConfig(level=optimization_level),
|
||||
**model_kwargs)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
@ -53,9 +52,8 @@ class Relu3(ReLUSquaredActivation):
|
||||
])
|
||||
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
|
||||
default_on: bool):
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
custom_ops=env.split(",")))
|
||||
level=torch_level, custom_ops=env.split(",")))
|
||||
with set_current_vllm_config(vllm_config):
|
||||
assert CustomOp.default_on() == default_on
|
||||
|
||||
|
||||
@ -1,24 +1,47 @@
|
||||
import glob
|
||||
import os
|
||||
import runpy
|
||||
import tempfile
|
||||
|
||||
import depyf
|
||||
|
||||
from vllm.config import CompilationLevel
|
||||
|
||||
# disable custom dispatcher, let Dynamo take over
|
||||
# all the control
|
||||
os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with depyf.prepare_debug(temp_dir):
|
||||
cur_dir = os.path.dirname(__file__)
|
||||
parent_dir = os.path.dirname(cur_dir)
|
||||
root_dir = os.path.dirname(parent_dir)
|
||||
example_file = os.path.join(root_dir, "examples",
|
||||
"offline_inference_tpu.py")
|
||||
runpy.run_path(example_file)
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
prompts = [
|
||||
"A robot may not injure a human being",
|
||||
"It is only with the heart that one can see rightly;",
|
||||
"The greatest glory in living lies not in never falling,",
|
||||
]
|
||||
answers = [
|
||||
" or, through inaction, allow a human being to come to harm.",
|
||||
" what is essential is invisible to the eye.",
|
||||
" but in rising every time we fall.",
|
||||
]
|
||||
N = 1
|
||||
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
|
||||
sampling_params = SamplingParams(temperature=0.7,
|
||||
top_p=1.0,
|
||||
n=N,
|
||||
max_tokens=16)
|
||||
|
||||
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
|
||||
# In real workloads, `enforce_eager` should be `False`.
|
||||
|
||||
# disable custom dispatcher, let Dynamo take over
|
||||
# all the control
|
||||
llm = LLM(model="google/gemma-2b",
|
||||
enforce_eager=True,
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.DYNAMO_AS_IS))
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output, answer in zip(outputs, answers):
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
assert generated_text.startswith(answer)
|
||||
|
||||
compiled_code = sorted(
|
||||
glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
|
||||
|
||||
@ -13,7 +13,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
|
||||
def test_custom_dispatcher():
|
||||
compare_two_settings(
|
||||
"google/gemma-2b",
|
||||
arg1=["--enforce-eager"],
|
||||
arg2=["--enforce-eager"],
|
||||
env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
|
||||
env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})
|
||||
arg1=["--enforce-eager", "-O",
|
||||
str(CompilationLevel.DYNAMO_ONCE)],
|
||||
arg2=["--enforce-eager", "-O",
|
||||
str(CompilationLevel.DYNAMO_AS_IS)],
|
||||
env1={},
|
||||
env2={})
|
||||
|
||||
Reference in New Issue
Block a user