[Misc] Remove unnecessary ModelRunner imports (#4703)
This commit is contained in:
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
class MockLogitsProcessor(LogitsProcessor):
|
||||
@@ -30,21 +30,15 @@ class MockLogitsProcessor(LogitsProcessor):
|
||||
|
||||
|
||||
def _prepare_test(
|
||||
batch_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
|
||||
batch_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
|
||||
fake_logits = torch.full((batch_size, vocab_size),
|
||||
1e-2,
|
||||
dtype=input_tensor.dtype)
|
||||
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
|
||||
model_runner = ModelRunner(model_config=None,
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
return input_tensor, fake_logits, logits_processor, model_runner
|
||||
return input_tensor, fake_logits, logits_processor
|
||||
|
||||
|
||||
RANDOM_SEEDS = list(range(128))
|
||||
@@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)
|
||||
|
||||
# This sample logits processor gives infinite score to the i-th token,
|
||||
# where i is the length of the input sequence.
|
||||
@@ -87,8 +80,8 @@ def test_logits_processors(seed: int, device: str):
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
device=device,
|
||||
pin_memory=is_pin_memory_available())
|
||||
logits_processor_output = logits_processor(
|
||||
embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
@@ -99,5 +92,3 @@ def test_logits_processors(seed: int, device: str):
|
||||
fake_logits *= logits_processor.scale
|
||||
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
|
||||
1e-4)
|
||||
|
||||
del model_runner
|
||||
|
||||
Reference in New Issue
Block a user