[Bugfix] Fix PP for Multi-Step (#8887)

This commit is contained in:
Varun Sundar Rabindranath
2024-09-28 11:52:46 -04:00
committed by GitHub
parent 39d3f8d94f
commit 19d02ff938
5 changed files with 130 additions and 15 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import functools
import os
import signal
@ -7,7 +8,7 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union
import openai
import pytest
@ -476,7 +477,8 @@ async def completions_with_server_args(
server_cli_args: List[str],
num_logprobs: Optional[int],
max_wait_seconds: int = 240,
) -> Completion:
max_tokens: Union[int, list] = 5,
) -> List[Completion]:
'''Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.
@ -487,37 +489,49 @@ async def completions_with_server_args(
num_logprobs: Number of logprobs to report (or `None`)
max_wait_seconds: timeout interval for bringing up server.
Default: 240sec
max_tokens: max_tokens value for each of the given input prompts.
if only one max_token value is given, the same value is used
for all the prompts.
Returns:
OpenAI Completion instance
'''
if isinstance(max_tokens, int):
max_tokens = [max_tokens] * len(prompts)
assert len(max_tokens) == len(prompts)
outputs = None
max_wait_seconds = 240 * 3 # 240 is default
with RemoteOpenAIServer(model_name,
server_cli_args,
max_wait_seconds=max_wait_seconds) as server:
client = server.get_async_client()
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5,
logprobs=num_logprobs)
outputs = [ client.completions.create(model=model_name,
prompt=[p],
temperature=0,
stream=False,
max_tokens=max_tok,
logprobs=num_logprobs) \
for p, max_tok in zip(prompts, max_tokens) ]
outputs = await asyncio.gather(*outputs)
assert outputs is not None, "Completion API call failed."
return outputs
def get_client_text_generations(completions: Completion) -> List[str]:
def get_client_text_generations(completions: List[Completion]) -> List[str]:
'''Extract generated tokens from the output of a
request made to an Open-AI-protocol completions endpoint.
'''
return [x.text for x in completions.choices]
assert all([len(x.choices) == 1 for x in completions])
return [x.choices[0].text for x in completions]
def get_client_text_logprob_generations(
completions: Completion) -> List[TextTextLogprobs]:
completions: List[Completion]) -> List[TextTextLogprobs]:
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
@ -526,4 +540,4 @@ def get_client_text_logprob_generations(
text = ''.join(text_generations)
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for x in completions.choices]
for completion in completions for x in completion.choices]