[Core][Bugfix] Support prompt_logprobs returned with speculative decoding (#8047)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Author: Travis Johnson
Date: 2024-09-24 18:29:56 -06:00
Committed by: GitHub
Parent: 13f9f7a3d0
Commit: 01b6f9e1f0

14 changed files with 486 additions and 128 deletions
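For orientation before the diff: this change makes prompt logprobs usable while speculative decoding is enabled, which the reworked test helpers below verify against a non-speculative baseline. A minimal usage sketch of the feature under test follows, assuming the speculative-decoding engine arguments of this vLLM generation (speculative_model, num_speculative_tokens); model names and parameter values are illustrative, not taken from the diff.

from vllm import LLM, SamplingParams

# Illustrative target/draft pair; some releases of this era also required
# use_v2_block_manager=True for speculative decoding.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
)

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=32,
    logprobs=1,          # logprobs for sampled tokens
    prompt_logprobs=1,   # the combination this commit fixes under spec decoding
)

outputs = llm.generate(["The president of the United States is"], sampling_params)
# With the fix, prompt logprobs are returned alongside sample logprobs.
print(outputs[0].prompt_logprobs)
print(outputs[0].outputs[0].logprobs)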


@@ -1,13 +1,16 @@
 from itertools import cycle
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import pytest
 
 from vllm import LLM, SamplingParams
 from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import PromptLogprobs, SampleLogprobs
 
 from ...conftest import cleanup
-from ...models.utils import check_logprobs_close, check_outputs_equal
+from ...models.utils import (TokensTextLogprobs,
+                             TokensTextLogprobsPromptLogprobs,
+                             check_logprobs_close, check_outputs_equal)
 from ...utils import RemoteOpenAIServer
 
 PROMPTS = [
@@ -81,45 +84,77 @@ def get_output_from_llm_generator(
     return tokens, token_ids, acceptance_rate
 
 
-def run_logprob_correctness_test(vllm_runner,
-                                 common_llm_kwargs,
-                                 per_test_common_llm_kwargs,
-                                 baseline_llm_kwargs,
-                                 test_llm_kwargs,
-                                 batch_size: int,
-                                 max_output_len: int,
-                                 seed: Optional[int] = 0,
-                                 temperature: float = 0.0,
-                                 logprobs: int = 1):
-    org_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **baseline_llm_kwargs,
-    }
-
-    sd_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **test_llm_kwargs,
-    }
-
-    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
-
-    sampling_params = SamplingParams(temperature=temperature,
-                                     max_tokens=max_output_len,
-                                     seed=seed,
-                                     logprobs=logprobs)
-
-    with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    with vllm_runner(**sd_args) as vllm_model:
-        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    check_logprobs_close(outputs_0_lst=org_outputs,
-                         outputs_1_lst=sd_outputs,
-                         name_0="org",
-                         name_1="sd")
+def check_logprobs_correctness(
+    spec_outputs: Sequence[Union[TokensTextLogprobs,
+                                 TokensTextLogprobsPromptLogprobs]],
+    baseline_outputs: Sequence[Union[TokensTextLogprobs,
+                                     TokensTextLogprobsPromptLogprobs]],
+    disable_logprobs: bool = False,
+):
+    """Compare sampled and prompt logprobs between baseline and spec decoding
+    """
+    if not disable_logprobs:
+        return check_logprobs_close(
+            outputs_0_lst=baseline_outputs,
+            outputs_1_lst=spec_outputs,
+            name_0="org",
+            name_1="sd",
+        )
+
+    # Check correctness when disable_logprobs == True
+    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
+        # Check generated token logprobs.
+        spec_logprobs = spec_output[2]
+        baseline_logprobs = baseline_output[2]
+        _check_logprobs_when_output_disabled(spec_logprobs,
+                                             baseline_logprobs,
+                                             is_prompt_logprobs=False)
+
+        # Check prompt logprobs too, if they exist
+        if len(baseline_output) == 4:
+            assert len(spec_output) == 4
+            spec_prompt_logprobs = spec_output[3]
+            baseline_prompt_logprobs = baseline_output[3]
+            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
+                                                 baseline_prompt_logprobs,
+                                                 is_prompt_logprobs=True)
+
+
+def _check_logprobs_when_output_disabled(
+    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    is_prompt_logprobs: bool = False,
+):
+    # Prompt logprobs are optional
+    if is_prompt_logprobs and baseline_logprobs is None:
+        assert spec_logprobs is None
+        return
+
+    assert spec_logprobs is not None
+    assert baseline_logprobs is not None
+    assert len(spec_logprobs) == len(baseline_logprobs)
+
+    # For each generated position of the sequence.
+    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+            zip(spec_logprobs, baseline_logprobs)):
+
+        # First prompt logprob is expected to be None
+        if is_prompt_logprobs and baseline_pos_logprobs is None:
+            assert spec_pos_logprobs is None
+            assert pos == 0
+            continue
+
+        assert spec_pos_logprobs is not None
+        assert baseline_pos_logprobs is not None
+
+        # When disabled, the 1 logprob is returned with dummy values for the
+        # score and rank, but the token id should match the baseline model
+        assert len(spec_pos_logprobs) == 1
+        (spec_pos_logprob_token_id,
+         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
+        assert spec_pos_logprob.rank == -1
+        assert spec_pos_logprob.logprob == 0.0
+        assert spec_pos_logprob_token_id in baseline_pos_logprobs
 
 
 def run_equality_correctness_test(
@@ -135,7 +170,10 @@ def run_equality_correctness_test(
         disable_seed: bool = False,
         ignore_eos: bool = True,
         ensure_all_accepted: bool = False,
-        expected_acceptance_rate: Optional[float] = None):
+        expected_acceptance_rate: Optional[float] = None,
+        logprobs: Optional[int] = None,
+        prompt_logprobs: Optional[int] = None,
+        disable_logprobs: bool = False):
 
     org_args = {
         **common_llm_kwargs,
@@ -157,10 +195,12 @@ def run_equality_correctness_test(
     sampling_params = SamplingParams(temperature=temperature,
                                      max_tokens=max_output_len,
                                      seed=seed,
-                                     ignore_eos=ignore_eos)
+                                     ignore_eos=ignore_eos,
+                                     logprobs=logprobs,
+                                     prompt_logprobs=prompt_logprobs)
 
     with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate(prompts, sampling_params)
+        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
 
     with vllm_runner(**sd_args) as vllm_model:
         if ensure_all_accepted or expected_acceptance_rate is not None:
@@ -169,7 +209,7 @@ def run_equality_correctness_test(
                 'prometheus']
             stat_logger.local_interval = -100
 
-        sd_outputs = vllm_model.generate(prompts, sampling_params)
+        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
 
     if ensure_all_accepted or expected_acceptance_rate is not None:
         acceptance_rate = (stat_logger.metrics.
@@ -185,11 +225,16 @@ def run_equality_correctness_test(
     if expected_acceptance_rate is not None:
         assert acceptance_rate >= expected_acceptance_rate - 1e-2
 
-    check_outputs_equal(outputs_0_lst=org_outputs,
-                        outputs_1_lst=sd_outputs,
+    # Only pass token entries, not the logprobs
+    check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
+                        outputs_1_lst=[out[0:2] for out in sd_outputs],
                         name_0="org",
                         name_1="sd")
 
+    # Check logprobs if requested
+    if logprobs is not None or prompt_logprobs is not None:
+        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
+
 
 def run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,