[torch.compile] remove compilation_context and simplify code (#10838)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-02 22:19:02 -08:00
committed by GitHub
parent 21fe7b481a
commit dc5ce861bf
14 changed files with 128 additions and 143 deletions

View File

@ -7,7 +7,6 @@ import torch
from torch import nn
from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
@ -96,11 +96,10 @@ def test_simple_piecewise_compile():
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
with set_compile_context([1, 2]):
model(inputs)
model(inputs)
model(torch.randn(2).cuda())
model(torch.randn(1).cuda())
model(torch.randn(2).cuda())
model(torch.randn(1).cuda())
input = torch.zeros(2).cuda()
global global_counter