[Bugfix] Fix PP for Multi-Step (#8887)
parent 39d3f8d94f
commit 19d02ff938
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import os
 import signal
@@ -7,7 +8,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

 import openai
 import pytest
@@ -476,7 +477,8 @@ async def completions_with_server_args(
     server_cli_args: List[str],
     num_logprobs: Optional[int],
     max_wait_seconds: int = 240,
-) -> Completion:
+    max_tokens: Union[int, list] = 5,
+) -> List[Completion]:
     '''Construct a remote OpenAI server, obtain an async client to the
     server & invoke the completions API to obtain completions.

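With the new signature, a caller can pass either one max_tokens value for all prompts or a per-prompt list. A minimal usage sketch from inside an async test (the model name and server flags below are placeholders, not taken from this commit):

    # Hypothetical caller; assumes the helper defined in this file.
    outputs = await completions_with_server_args(
        prompts=["Hello, my name is", "The capital of France is"],
        model_name="facebook/opt-125m",       # placeholder model
        server_cli_args=["--enforce-eager"],  # placeholder CLI flags
        num_logprobs=None,
        max_tokens=[5, 8],  # one max_tokens value per prompt
    )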
@@ -487,37 +489,49 @@ async def completions_with_server_args(
       num_logprobs: Number of logprobs to report (or `None`)
       max_wait_seconds: timeout interval for bringing up server.
                         Default: 240sec
+      max_tokens: max_tokens value for each of the given input prompts.
+        if only one max_token value is given, the same value is used
+        for all the prompts.

     Returns:
       OpenAI Completion instance
     '''

+    if isinstance(max_tokens, int):
+        max_tokens = [max_tokens] * len(prompts)
+
+    assert len(max_tokens) == len(prompts)
+
     outputs = None
+    max_wait_seconds = 240 * 3  # 240 is default
     with RemoteOpenAIServer(model_name,
                             server_cli_args,
                             max_wait_seconds=max_wait_seconds) as server:
         client = server.get_async_client()
-        outputs = await client.completions.create(model=model_name,
-                                                  prompt=prompts,
-                                                  temperature=0,
-                                                  stream=False,
-                                                  max_tokens=5,
-                                                  logprobs=num_logprobs)
+        outputs = [ client.completions.create(model=model_name,
+                                              prompt=[p],
+                                              temperature=0,
+                                              stream=False,
+                                              max_tokens=max_tok,
+                                              logprobs=num_logprobs) \
+                    for p, max_tok in zip(prompts, max_tokens) ]
+        outputs = await asyncio.gather(*outputs)

     assert outputs is not None, "Completion API call failed."

     return outputs


-def get_client_text_generations(completions: Completion) -> List[str]:
+def get_client_text_generations(completions: List[Completion]) -> List[str]:
     '''Extract generated tokens from the output of a
     request made to an Open-AI-protocol completions endpoint.
     '''
-    return [x.text for x in completions.choices]
+    assert all([len(x.choices) == 1 for x in completions])
+    return [x.choices[0].text for x in completions]


 def get_client_text_logprob_generations(
-        completions: Completion) -> List[TextTextLogprobs]:
+        completions: List[Completion]) -> List[TextTextLogprobs]:
     '''Operates on the output of a request made to an Open-AI-protocol
     completions endpoint; obtains top-rank logprobs for each token in
     each :class:`SequenceGroup`
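The heart of the fix is the request fan-out: one completions coroutine per prompt, awaited together with asyncio.gather, so each prompt carries its own max_tokens instead of the hard-coded 5. A self-contained sketch of the same pattern (fake_completion is a stand-in for client.completions.create, not part of this commit):

    import asyncio
    from typing import List

    async def fake_completion(prompt: str, max_tokens: int) -> str:
        # Stand-in for an OpenAI-style completions call.
        await asyncio.sleep(0)
        return f"{prompt!r}: up to {max_tokens} tokens"

    async def fan_out(prompts: List[str], max_tokens: List[int]) -> List[str]:
        # One in-flight request per prompt, mirroring the
        # zip(prompts, max_tokens) comprehension in the diff above.
        coros = [fake_completion(p, mt) for p, mt in zip(prompts, max_tokens)]
        return await asyncio.gather(*coros)

    print(asyncio.run(fan_out(["a", "b"], [5, 8])))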
@@ -526,4 +540,4 @@ def get_client_text_logprob_generations(
     text = ''.join(text_generations)
     return [(text_generations, text,
              (None if x.logprobs is None else x.logprobs.top_logprobs))
-            for x in completions.choices]
+            for completion in completions for x in completion.choices]
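Because the helper now returns one single-choice Completion per prompt, the accessors index choices[0] of each completion rather than iterating the choices of one batched response. A sketch with minimal stub types (Choice and Completion here are simplified stand-ins for the openai response classes):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Choice:  # stub: one generated candidate
        text: str
        logprobs: Optional[object] = None

    @dataclass
    class Completion:  # stub: per-prompt response holding one choice
        choices: List[Choice]

    def get_client_text_generations(completions: List[Completion]) -> List[str]:
        # Each per-prompt completion is expected to hold exactly one choice.
        assert all(len(c.choices) == 1 for c in completions)
        return [c.choices[0].text for c in completions]

    print(get_client_text_generations(
        [Completion([Choice("Hello")]), Completion([Choice(" world")])]))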