[torch.compile] remove compilation_context and simplify code (#10838)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-12-02 22:19:02 -08:00
Committed by: GitHub
Parent commit: 21fe7b481a
Commit: dc5ce861bf

14 changed files with 128 additions and 143 deletions


@@ -4,12 +4,12 @@ from typing import List
 import pytest
 import torch

+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
-from vllm.worker.model_runner import _get_graph_batch_size

 BATCH_SIZES = [1, 4, 16, 64, 256]
@@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
     cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
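
The hunk above pads the decode batch out to the CUDA graph capture size, so the amount of padding follows directly from whatever VllmConfig.get_graph_batch_size returns. The diff only shows the call sites; as orientation, here is a minimal sketch of the rounding such a helper performs, assuming graphs are captured for batch sizes 1, 2, 4 and then multiples of 8 (the alignment constant is an assumption, not taken from this commit):

# Hypothetical sketch of the batch-size rounding behind CUDA graph padding.
# Assumption: graphs are captured at batch sizes 1, 2, 4 and multiples of 8.
_BATCH_SIZE_ALIGNMENT = 8

def get_graph_batch_size(batch_size: int) -> int:
    """Round a runtime batch size up to the nearest captured graph size."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
            // _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT

# Under this assumption an expanded batch of 9 sequences is padded to 16, and
# the test fills the extra slots with length-1 dummy sequences:
# cuda_graph_pad_size = get_graph_batch_size(expanded_batch_size) - expanded_batch_size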


@@ -3,13 +3,14 @@ from typing import List
 import pytest
 import torch

+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
-from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
+from vllm.worker.model_runner import ModelRunner


 def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
@@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)
-    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
+    expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.num_prefills == 0
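
For callers that imported the removed helper, the migration shown in these hunks is mechanical: drop the _get_graph_batch_size import from vllm.worker.model_runner and call the method on VllmConfig instead. A short usage sketch (assumes vLLM at this commit is installed; num_seqs is a made-up example value):

from vllm.config import VllmConfig

# Before this commit (the import no longer exists):
#   from vllm.worker.model_runner import _get_graph_batch_size
#   graph_bs = _get_graph_batch_size(num_seqs)

num_seqs = 13  # hypothetical number of decode sequences in a batch
graph_bs = VllmConfig.get_graph_batch_size(num_seqs)
cuda_graph_pad_size = graph_bs - num_seqs  # dummy slots to fill before graph replay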