Support Batch Completion in Server (#2529)
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
import time
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import sys
|
||||
import pytest
|
||||
@ -17,8 +18,11 @@ pytestmark = pytest.mark.asyncio
|
||||
class ServerRunner:
|
||||
|
||||
def __init__(self, args):
|
||||
env = os.environ.copy()
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
self.proc = subprocess.Popen(
|
||||
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
|
||||
env=env,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
@ -58,7 +62,8 @@ def server():
|
||||
"--dtype",
|
||||
"bfloat16", # use half precision for speed and memory savings in CI environment
|
||||
"--max-model-len",
|
||||
"8192"
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
])
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
|
||||
assert "".join(chunks) == output
|
||||
|
||||
|
||||
async def test_batch_completions(server, client: openai.AsyncOpenAI):
    """Check completion requests whose ``prompt`` is a list of prompts.

    The same prompt is sent twice per request in three variants: a plain
    batch, a batch with ``n=2`` and beam search, and a streamed batch.
    """
    prompts = ["Hello, my name is"] * 2

    # Plain list of prompts: one choice per prompt, and the two choices
    # are expected to carry the same text.
    simple = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompts,
        max_tokens=5,
        temperature=0.0,
    )
    assert len(simple.choices) == 2
    assert simple.choices[0].text == simple.choices[1].text

    # n = 2: two choices per prompt, four in total.
    beamed = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompts,
        n=2,
        max_tokens=5,
        temperature=0.0,
        # NOTE: vLLM requires beam search to be enabled when n > 1; the
        # official client does not need this flag.
        extra_body={"use_beam_search": True},
    )
    beam_choices = beamed.choices
    assert len(beam_choices) == 4
    assert beam_choices[0].text != beam_choices[1].text, (
        "beam search should be different")
    assert beam_choices[0].text == beam_choices[2].text, (
        "two copies of the same prompt should be the same")
    assert beam_choices[1].text == beam_choices[3].text, (
        "two copies of the same prompt should be the same")

    # Streaming: every chunk carries exactly one choice; its text is
    # appended to the slot selected by the choice's index.
    stream = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompts,
        max_tokens=5,
        temperature=0.0,
        stream=True,
    )
    streamed_texts = ["", ""]
    async for chunk in stream:
        assert len(chunk.choices) == 1
        piece = chunk.choices[0]
        streamed_texts[piece.index] += piece.text
    assert streamed_texts[0] == streamed_texts[1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Running this file directly hands it to pytest as the only target.
    pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user