Support Batch Completion in Server (#2529)

This commit is contained in:
Simon Mo
2024-01-24 17:11:07 -08:00
committed by GitHub
parent 223c19224b
commit 3a7dd7e367
2 changed files with 214 additions and 104 deletions

View File

@ -1,5 +1,6 @@
import time
import os
import subprocess
import time
import sys
import pytest
@ -17,8 +18,11 @@ pytestmark = pytest.mark.asyncio
class ServerRunner:
def __init__(self, args):
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
@ -58,7 +62,8 @@ def server():
"--dtype",
"bfloat16", # use half precision for speed and memory savings in CI environment
"--max-model-len",
"8192"
"8192",
"--enforce-eager",
])
ray.get(server_runner.ready.remote())
yield server_runner
@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
assert "".join(chunks) == output
async def test_batch_completions(server, client: openai.AsyncOpenAI):
    """Check that the completions endpoint accepts a list of prompts.

    Three request shapes are exercised with the same duplicated prompt:
    a plain batch, a batch with ``n=2`` using beam search, and a streamed
    batch. In each case the two copies of the prompt are expected to
    produce matching text.
    """
    # The same prompt is submitted twice so the outputs can be compared.
    prompts = ["Hello, my name is"] * 2

    # Case 1: simple list of prompts -> one choice per prompt.
    simple = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompts,
        max_tokens=5,
        temperature=0.0,
    )
    assert len(simple.choices) == 2
    simple_texts = [choice.text for choice in simple.choices]
    assert simple_texts[0] == simple_texts[1]

    # Case 2: n = 2 -> two choices per prompt, four in total.
    beamed = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompts,
        n=2,
        max_tokens=5,
        temperature=0.0,
        # NOTE: vLLM needs beam search enabled for n > 1; the official
        # client does not require this flag.
        extra_body={"use_beam_search": True},
    )
    assert len(beamed.choices) == 4
    beam_texts = [choice.text for choice in beamed.choices]
    assert beam_texts[0] != beam_texts[1], "beam search should be different"
    assert beam_texts[0] == beam_texts[2], (
        "two copies of the same prompt should be the same")
    assert beam_texts[1] == beam_texts[3], (
        "two copies of the same prompt should be the same")

    # Case 3: streaming -> every chunk carries exactly one choice, and the
    # choice's index says which prompt the text fragment belongs to.
    stream = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompts,
        max_tokens=5,
        temperature=0.0,
        stream=True,
    )
    streamed = ["", ""]
    async for chunk in stream:
        assert len(chunk.choices) == 1
        fragment = chunk.choices[0]
        streamed[fragment.index] += fragment.text
    assert streamed[0] == streamed[1]
# Allow running this test module directly as a script: hand this file's
# path to pytest so it collects and runs the tests defined here.
if __name__ == "__main__":
    pytest.main([__file__])