[torch.compile] remove compilation_context and simplify code (#10838)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -4,12 +4,12 @@ from typing import List
 import pytest
 import torch
 
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
-from vllm.worker.model_runner import _get_graph_batch_size
 
 BATCH_SIZES = [1, 4, 16, 64, 256]
 
@@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
     cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
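The call-site change above relies on the batch being rounded up to the nearest CUDA-graph-captured size before dummy sequences are appended. The implementation of VllmConfig.get_graph_batch_size is not shown in this diff; what follows is a minimal sketch of the rounding rule these tests exercise, assuming captured sizes of 1, 2, 4, then multiples of an alignment of 8 (vLLM's historical default), not the code from this commit:

# Hedged sketch, not the code from this commit.
_BATCH_SIZE_ALIGNMENT = 8  # assumption, mirrors vLLM's historical default


def get_graph_batch_size(batch_size: int) -> int:
    """Round batch_size up to the nearest captured CUDA graph size."""
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    # Otherwise round up to the next multiple of the alignment.
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


# Example: an expanded batch of 10 rounds up to 16, so
# cuda_graph_pad_size == 6 and six dummy sequences of length 1 are
# appended -- exactly the arithmetic the hunk above checks.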
@@ -3,13 +3,14 @@ from typing import List
 import pytest
 import torch
 
+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
-from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
+from vllm.worker.model_runner import ModelRunner
 
 
 def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
@@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)
 
-    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
+    expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.num_prefills == 0
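For downstream code that still imports the removed private helper, the migration these hunks perform is mechanical. A hedged before/after sketch, assuming only the import path and owning class changed, which is all the hunks show (num_seqs is an illustrative value, not from the diff):

# Before this commit (private module-level helper):
# from vllm.worker.model_runner import _get_graph_batch_size
# expected_bs = _get_graph_batch_size(num_seqs)

# After this commit (method on the config class):
from vllm.config import VllmConfig

num_seqs = 10  # illustrative only
expected_bs = VllmConfig.get_graph_batch_size(num_seqs)
assert expected_bs >= num_seqs  # padding never shrinks the batch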