[Misc] Remove unnecessary ModelRunner imports (#4703)

This commit is contained in:
Woosuk Kwon
2024-05-09 00:17:17 -07:00
committed by GitHub
parent f12b20decc
commit 190bc838e1
2 changed files with 31 additions and 73 deletions

View File

@ -9,7 +9,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.utils import is_pin_memory_available
class MockLogitsProcessor(LogitsProcessor):
@ -30,21 +30,15 @@ class MockLogitsProcessor(LogitsProcessor):
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
return input_tensor, fake_logits, logits_processor, model_runner
return input_tensor, fake_logits, logits_processor
RANDOM_SEEDS = list(range(128))
@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
batch_size)
input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
@ -87,8 +80,8 @@ def test_logits_processors(seed: int, device: str):
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
embedding=None,
hidden_states=input_tensor,
@ -99,5 +92,3 @@ def test_logits_processors(seed: int, device: str):
fake_logits *= logits_processor.scale
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
1e-4)
del model_runner