[6/N] torch.compile rollout to users (#10437)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-19 10:09:03 -08:00
committed by GitHub
parent fd9f124971
commit 803f37eaaa
15 changed files with 129 additions and 141 deletions

View File

@ -1,5 +0,0 @@
{
"use_cudagraph": true,
"non_cudagraph_ops": ["silly.attention"],
"cudagraph_copy_inputs": true
}

View File

@ -2,7 +2,6 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import os
import torch
from torch import nn
@ -11,7 +10,7 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op
@ -77,12 +76,12 @@ class SillyModel(nn.Module):
def test_simple_piecewise_compile():
directory = os.path.dirname(__file__)
config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
cudagraph_copy_inputs=True,
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
# clean up to avoid side effects for other tests
del os.environ["VLLM_TORCH_COMPILE_CONFIG"]

View File

@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
import os
from dataclasses import dataclass
from typing import Optional, Tuple
@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_compilation_config, set_current_vllm_config
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
@ -254,23 +253,17 @@ def run_model(llama_config,
split_attn: bool = False) -> torch.Tensor:
if use_compile:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
CompilationLevel.PIECEWISE)
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
)
if split_attn:
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
))
else:
set_compilation_config(CompilationConfig(use_cudagraph=True, ))
compilation_config.non_cudagraph_ops = ["silly.attention"]
else:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
CompilationLevel.NO_COMPILATION)
set_compilation_config(None)
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
@ -288,10 +281,6 @@ def run_model(llama_config,
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
# manual cleanup
del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
set_compilation_config(None)
output = output.cpu()
if llama_config.tractable_init:
@ -361,7 +350,6 @@ def test_toy_llama():
@torch.inference_mode
def benchmark():
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
from triton.testing import do_bench
# similar to llama 3.1-8B
@ -387,15 +375,16 @@ def benchmark():
for piecewise in [False, True]:
if piecewise:
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
))
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
)
else:
set_compilation_config(None)
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, )
vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,