Support per-request seed (#2514)

Nick Hill
2024-02-21 11:47:00 -08:00
committed by GitHub
parent dc903e70ac
commit 7d2dcce175
10 changed files with 296 additions and 91 deletions
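
The feature under test here is the new optional seed field on SamplingParams: a request that carries a seed gets its own reproducible random stream, while requests without one keep the existing non-deterministic behavior. As a quick orientation before the diff, here is a minimal usage sketch against vLLM's offline LLM entrypoint; the model name is a placeholder, and the assertion simply restates at the API level what test_sampler_all_random_seed_deterministic checks below:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration

# Same prompt, same seed: the sampled continuation should be identical,
# even at temperature 1.0. Omitting `seed` keeps the old behavior.
params = SamplingParams(temperature=1.0, seed=1234)
first = llm.generate(["The capital of France is"], params)
second = llm.generate(["The capital of France is"], params)
assert first[0].outputs[0].text == second[0].outputs[0].text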


@@ -1,10 +1,11 @@
 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch
 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -46,6 +47,34 @@ CUDA_DEVICES = [
 ]
+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=sampling_params,
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_greedy(seed: int, device: str):
@@ -55,25 +84,9 @@ def test_sampler_all_greedy(seed: int, device: str):
     input_tensor, fake_logits, sampler, model_runner = _prepare_test(
         batch_size)
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -94,28 +107,13 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -123,6 +121,58 @@
     del model_runner
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+    for i in range(batch_size):
+        fake_logits[i, i] = 1e2
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+    del model_runner
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+    assert first_sampler_output == second_sampler_output
+    del model_runner
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner,
+               sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)
     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
         n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        for idx in range(n):
-            fake_logits[i, i + idx] = 1e2
-            expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record seeded random result to compare with results of second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+            for n, nth_output in enumerate(sequence_output.samples):
+                if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
+                    # Ensure exact matches for greedy or random with seed
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For non-seeded random check that one of the high-logit tokens was chosen
+                    assert nth_output.output_token in expected_tokens[i]
+    # Test batch
+    test_sampling(model_runner)
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+    # This time, results of seeded random samples will be compared with the corresponding
+    # sample in the pre-shuffled batch
+    test_sampling(model_runner)
     del model_runner
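
The sampler-side mechanics that make the seeded assertions above hold are not part of this file's diff, but the general technique is to attach a torch.Generator to each seeded request and thread it through the multinomial draw. The following is an illustrative sketch of that technique only, with hypothetical names; it is not the code this commit adds to vllm/model_executor/layers/sampler.py:

from typing import Optional
import torch

def sample_token(logits: torch.Tensor, seed: Optional[int] = None) -> int:
    # A per-request generator isolates this request's randomness from the
    # global RNG state, so equal seeds yield equal samples regardless of
    # how requests are batched or ordered.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=logits.device).manual_seed(seed)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator).item()

logits = torch.randn(32000)
assert sample_token(logits, seed=42) == sample_token(logits, seed=42)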