[Core][Bugfix] Support prompt_logprobs returned with speculative decoding (#8047)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
@ -1,13 +1,16 @@
|
||||
from itertools import cycle
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import PromptLogprobs, SampleLogprobs
|
||||
|
||||
from ...conftest import cleanup
|
||||
from ...models.utils import check_logprobs_close, check_outputs_equal
|
||||
from ...models.utils import (TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
check_logprobs_close, check_outputs_equal)
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
PROMPTS = [
|
||||
@ -81,45 +84,77 @@ def get_output_from_llm_generator(
|
||||
return tokens, token_ids, acceptance_rate
|
||||
|
||||
|
||||
def run_logprob_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size: int,
|
||||
max_output_len: int,
|
||||
seed: Optional[int] = 0,
|
||||
temperature: float = 0.0,
|
||||
logprobs: int = 1):
|
||||
org_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**baseline_llm_kwargs,
|
||||
}
|
||||
def check_logprobs_correctness(
|
||||
spec_outputs: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs]],
|
||||
baseline_outputs: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs]],
|
||||
disable_logprobs: bool = False,
|
||||
):
|
||||
"""Compare sampled and prompt logprobs between baseline and spec decoding
|
||||
"""
|
||||
if not disable_logprobs:
|
||||
return check_logprobs_close(
|
||||
outputs_0_lst=baseline_outputs,
|
||||
outputs_1_lst=spec_outputs,
|
||||
name_0="org",
|
||||
name_1="sd",
|
||||
)
|
||||
|
||||
sd_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**test_llm_kwargs,
|
||||
}
|
||||
# Check correctness when disable_logprobs == True
|
||||
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
|
||||
# Check generated token logprobs.
|
||||
spec_logprobs = spec_output[2]
|
||||
baseline_logprobs = baseline_output[2]
|
||||
_check_logprobs_when_output_disabled(spec_logprobs,
|
||||
baseline_logprobs,
|
||||
is_prompt_logprobs=False)
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||
# Check prompt logprobs too, if they exist
|
||||
if len(baseline_output) == 4:
|
||||
assert len(spec_output) == 4
|
||||
spec_prompt_logprobs = spec_output[3]
|
||||
baseline_prompt_logprobs = baseline_output[3]
|
||||
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
|
||||
baseline_prompt_logprobs,
|
||||
is_prompt_logprobs=True)
|
||||
|
||||
sampling_params = SamplingParams(temperature=temperature,
|
||||
max_tokens=max_output_len,
|
||||
seed=seed,
|
||||
logprobs=logprobs)
|
||||
|
||||
with vllm_runner(**org_args) as vllm_model:
|
||||
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
def _check_logprobs_when_output_disabled(
|
||||
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
|
||||
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
|
||||
is_prompt_logprobs: bool = False,
|
||||
):
|
||||
# Prompt logprobs are optional
|
||||
if is_prompt_logprobs and baseline_logprobs is None:
|
||||
assert spec_logprobs is None
|
||||
return
|
||||
|
||||
with vllm_runner(**sd_args) as vllm_model:
|
||||
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
assert spec_logprobs is not None
|
||||
assert baseline_logprobs is not None
|
||||
assert len(spec_logprobs) == len(baseline_logprobs)
|
||||
|
||||
check_logprobs_close(outputs_0_lst=org_outputs,
|
||||
outputs_1_lst=sd_outputs,
|
||||
name_0="org",
|
||||
name_1="sd")
|
||||
# For each generated position of the sequence.
|
||||
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
|
||||
zip(spec_logprobs, baseline_logprobs)):
|
||||
|
||||
# First prompt logprob is expected to be None
|
||||
if is_prompt_logprobs and baseline_pos_logprobs is None:
|
||||
assert spec_pos_logprobs is None
|
||||
assert pos == 0
|
||||
continue
|
||||
|
||||
assert spec_pos_logprobs is not None
|
||||
assert baseline_pos_logprobs is not None
|
||||
|
||||
# When disabled, the 1 logprob is returned with dummy values for the
|
||||
# score and rank, but the token id should match the baseline model
|
||||
assert len(spec_pos_logprobs) == 1
|
||||
(spec_pos_logprob_token_id,
|
||||
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
|
||||
assert spec_pos_logprob.rank == -1
|
||||
assert spec_pos_logprob.logprob == 0.0
|
||||
assert spec_pos_logprob_token_id in baseline_pos_logprobs
|
||||
|
||||
|
||||
def run_equality_correctness_test(
|
||||
@ -135,7 +170,10 @@ def run_equality_correctness_test(
|
||||
disable_seed: bool = False,
|
||||
ignore_eos: bool = True,
|
||||
ensure_all_accepted: bool = False,
|
||||
expected_acceptance_rate: Optional[float] = None):
|
||||
expected_acceptance_rate: Optional[float] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
disable_logprobs: bool = False):
|
||||
|
||||
org_args = {
|
||||
**common_llm_kwargs,
|
||||
@ -157,10 +195,12 @@ def run_equality_correctness_test(
|
||||
sampling_params = SamplingParams(temperature=temperature,
|
||||
max_tokens=max_output_len,
|
||||
seed=seed,
|
||||
ignore_eos=ignore_eos)
|
||||
ignore_eos=ignore_eos,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
|
||||
with vllm_runner(**org_args) as vllm_model:
|
||||
org_outputs = vllm_model.generate(prompts, sampling_params)
|
||||
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
with vllm_runner(**sd_args) as vllm_model:
|
||||
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||
@ -169,7 +209,7 @@ def run_equality_correctness_test(
|
||||
'prometheus']
|
||||
stat_logger.local_interval = -100
|
||||
|
||||
sd_outputs = vllm_model.generate(prompts, sampling_params)
|
||||
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||
acceptance_rate = (stat_logger.metrics.
|
||||
@ -185,11 +225,16 @@ def run_equality_correctness_test(
|
||||
if expected_acceptance_rate is not None:
|
||||
assert acceptance_rate >= expected_acceptance_rate - 1e-2
|
||||
|
||||
check_outputs_equal(outputs_0_lst=org_outputs,
|
||||
outputs_1_lst=sd_outputs,
|
||||
# Only pass token entries, not the logprobs
|
||||
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
|
||||
outputs_1_lst=[out[0:2] for out in sd_outputs],
|
||||
name_0="org",
|
||||
name_1="sd")
|
||||
|
||||
# Check logprobs if requested
|
||||
if logprobs is not None or prompt_logprobs is not None:
|
||||
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
|
||||
|
||||
|
||||
def run_equality_correctness_test_tp(model,
|
||||
common_llm_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user