Support per-request seed (#2514)
This commit is contained in:
@ -1,10 +1,11 @@
|
||||
import random
|
||||
from typing import Tuple
|
||||
from typing import Tuple, List
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import GenerationConfig, GenerationMixin
|
||||
from typing import Optional
|
||||
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
@ -46,6 +47,34 @@ CUDA_DEVICES = [
|
||||
]
|
||||
|
||||
|
||||
def _do_sample(
|
||||
batch_size: int,
|
||||
input_tensor: torch.Tensor,
|
||||
sampler: MockLogitsSampler,
|
||||
model_runner: ModelRunner,
|
||||
sampling_params: SamplingParams,
|
||||
):
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
return sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_greedy(seed: int, device: str):
|
||||
@ -55,25 +84,9 @@ def test_sampler_all_greedy(seed: int, device: str):
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=SamplingParams(temperature=0, ),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
||||
model_runner, sampling_params)
|
||||
expected = torch.argmax(fake_logits, dim=-1)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
@ -94,28 +107,13 @@ def test_sampler_all_random(seed: int, device: str):
|
||||
for i in range(batch_size):
|
||||
fake_logits[i, i] = 1e2
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
)
|
||||
sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
||||
model_runner, sampling_params)
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token == i
|
||||
@ -123,6 +121,58 @@ def test_sampler_all_random(seed: int, device: str):
|
||||
del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_random_seed(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
|
||||
for i in range(batch_size):
|
||||
fake_logits[i, i] = 1e2
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
seed=random.randint(0, 10000),
|
||||
)
|
||||
sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
||||
model_runner, sampling_params)
|
||||
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token == i
|
||||
|
||||
del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
seed=random.randint(0, 10000),
|
||||
)
|
||||
first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
||||
model_runner, sampling_params)
|
||||
|
||||
second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
||||
model_runner, sampling_params)
|
||||
|
||||
assert first_sampler_output == second_sampler_output
|
||||
|
||||
del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_beam(seed: int, device: str):
|
||||
@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=SamplingParams(
|
||||
temperature=0,
|
||||
best_of=2,
|
||||
use_beam_search=True,
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
best_of=2,
|
||||
use_beam_search=True,
|
||||
)
|
||||
_do_sample(batch_size, input_tensor, sampler, model_runner,
|
||||
sampling_params)
|
||||
# no assertion here as I am not sure how to determine whether
|
||||
# the outputs are expected - in other words, this just tests
|
||||
# whether there are no exceptions in the sampler
|
||||
@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
expected_tokens = []
|
||||
expected_tokens: List[Optional[List[int]]] = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
n = 1
|
||||
sampling_type = random.randint(0, 2)
|
||||
expected: Optional[List[int]] = None
|
||||
sampling_type = random.randint(0, 3)
|
||||
if sampling_type == 0:
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
elif sampling_type == 1:
|
||||
expected = [torch.argmax(fake_logits[i], dim=-1).item()]
|
||||
elif sampling_type in (1, 2):
|
||||
n = random.randint(1, 10)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=random.random() + 0.1,
|
||||
@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
n=n,
|
||||
presence_penalty=random.randint(0, 1),
|
||||
)
|
||||
if sampling_type == 2:
|
||||
sampling_params.seed = random.randint(0, 10000)
|
||||
else:
|
||||
for idx in range(n):
|
||||
fake_logits[i, i + idx] = 1e2
|
||||
expected = list(range(i, i + n))
|
||||
else:
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
use_beam_search=True,
|
||||
best_of=2)
|
||||
for idx in range(n):
|
||||
fake_logits[i, i + idx] = 1e2
|
||||
expected_tokens.append(i + idx)
|
||||
expected_tokens.append(expected)
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
if seq_group_metadata_list[i].sampling_params.use_beam_search:
|
||||
continue
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token in expected_tokens
|
||||
def test_sampling(model_runner: ModelRunner):
|
||||
sampling_metadata = model_runner._prepare_sample(
|
||||
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
for i, (sequence_output, metadata) in enumerate(
|
||||
zip(sampler_output, seq_group_metadata_list)):
|
||||
if metadata.sampling_params.use_beam_search:
|
||||
continue
|
||||
|
||||
if metadata.sampling_params.seed is not None \
|
||||
and expected_tokens[i] is None:
|
||||
# Record seeded random result to compare with results of second invocation
|
||||
expected_tokens[i] = [
|
||||
nth_output.output_token
|
||||
for nth_output in sequence_output.samples
|
||||
]
|
||||
continue
|
||||
|
||||
for n, nth_output in enumerate(sequence_output.samples):
|
||||
if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
|
||||
# Ensure exact matches for greedy or random with seed
|
||||
assert nth_output.output_token == expected_tokens[i][n]
|
||||
else:
|
||||
# For non-seeded random check that one of the high-logit tokens were chosen
|
||||
assert nth_output.output_token in expected_tokens[i]
|
||||
|
||||
# Test batch
|
||||
test_sampling(model_runner)
|
||||
|
||||
# Shuffle the batch and resample
|
||||
target_index = list(range(batch_size))
|
||||
for list_to_shuffle in (target_index, seq_group_metadata_list,
|
||||
expected_tokens, prompt_lens):
|
||||
random.Random(seed).shuffle(list_to_shuffle)
|
||||
target_index = torch.tensor(target_index)
|
||||
input_tensor.data = input_tensor.index_select(0, target_index)
|
||||
fake_logits.data = fake_logits.index_select(0, target_index)
|
||||
|
||||
# This time, results of seeded random samples will be compared with the corresponding
|
||||
# sample in the pre-shuffled batch
|
||||
test_sampling(model_runner)
|
||||
|
||||
del model_runner
|
||||
|
||||
|
||||
Reference in New Issue
Block a user