OpenAI Compatible Frontend (#116)

This commit is contained in:
Zhuohan Li
2023-05-23 21:39:50 -07:00
committed by GitHub
parent e86717833d
commit 057daef778
20 changed files with 644 additions and 169 deletions

View File

@ -0,0 +1,44 @@
import argparse
import json
import gradio as gr
import requests
def http_bot(prompt):
    """Stream model completions for *prompt*.

    Posts the prompt to ``args.model_url`` (module-level CLI args) and yields
    each partial completion text as it arrives on the NUL-delimited stream,
    so the Gradio output box updates incrementally.
    """
    payload = {
        "prompt": prompt,
        "max_tokens": 128,
    }
    resp = requests.post(
        args.model_url,
        headers={"User-Agent": "Cacheflow Client"},
        json=payload,
        stream=True,
    )
    # Server frames each JSON update with a trailing NUL byte.
    for raw in resp.iter_lines(chunk_size=8192, decode_unicode=False,
                               delimiter=b"\0"):
        if not raw:
            continue
        parsed = json.loads(raw.decode("utf-8"))
        yield parsed["text"][0]
def build_demo():
    """Assemble and return the Gradio Blocks UI for the completion demo."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Cacheflow text completion demo\n"
        )
        input_box = gr.Textbox(
            label="Input", placeholder="Enter text and press ENTER")
        output_box = gr.Textbox(
            label="Output", placeholder="Generated result from the model")
        # ENTER in the input box streams http_bot's partial outputs
        # into the output box.
        input_box.submit(http_bot, [input_box], [output_box])
    return demo
if __name__ == "__main__":
    # CLI: where to serve the demo UI and where the model server lives.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8002)
    parser.add_argument("--model-url", type=str,
                        default="http://localhost:8001/generate")
    args = parser.parse_args()

    demo = build_demo()
    # Queue allows many concurrent generations; share=True exposes a
    # public Gradio link.
    demo.queue(concurrency_count=100).launch(
        server_name=args.host,
        server_port=args.port,
        share=True,
    )

22
examples/openai_client.py Normal file
View File

@ -0,0 +1,22 @@
import openai

# Point the OpenAI SDK at the locally running OpenAI-compatible server.
# The key is unused by the local server but the SDK requires a value.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"

# list models
print(openai.Model.list())

# create a completion
stream = True
completion = openai.Completion.create(
    model=model,
    prompt="A robot may not injure a human being",
    echo=False,
    n=2,
    best_of=3,
    stream=stream,
    logprobs=3,
)

# print the completion
if stream:
    # Streaming mode yields incremental chunks.
    for chunk in completion:
        print(chunk)
else:
    print("completion:", completion)

View File

@ -0,0 +1,48 @@
import argparse
import requests
import json
def clear_line(n=1):
    """Erase the previous *n* terminal lines via ANSI escape codes."""
    cursor_up = '\033[1A'   # move cursor up one line
    erase_line = '\x1b[2K'  # clear the entire current line
    for _ in range(n):
        print(cursor_up, end=erase_line, flush=True)
def http_request(prompt: str, api_url: str, n: int = 1):
    """POST *prompt* to the generation server and stream back candidates.

    Yields the list of candidate texts from each NUL-delimited JSON record
    the server emits (beam search, n candidates, up to 16 tokens).
    """
    payload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
    }
    resp = requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=payload,
        stream=True,
    )
    # Each server update is one JSON document terminated by a NUL byte.
    for record in resp.iter_lines(chunk_size=8192, decode_unicode=False,
                                  delimiter=b"\0"):
        if not record:
            continue
        yield json.loads(record.decode("utf-8"))["text"]
if __name__ == "__main__":
    # CLI: server location, number of beam candidates, and the prompt.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    args = parser.parse_args()

    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n

    print(f"Prompt: {prompt}\n", flush=True)
    num_printed_lines = 0
    for candidates in http_request(prompt, api_url, n):
        # Overwrite the previously printed candidates in place so the
        # terminal shows only the latest beam state.
        clear_line(num_printed_lines)
        num_printed_lines = 0
        for idx, text in enumerate(candidates):
            num_printed_lines += 1
            print(f"Beam candidate {idx}: {text}", flush=True)

View File

@ -1,5 +1,4 @@
import argparse
import uuid
from cacheflow import ServerArgs, LLMServer, SamplingParams
@ -20,17 +19,19 @@ def main(args: argparse.Namespace):
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
]
request_id = 0
# Run the server.
while True:
# To test iteration-level scheduling, we add one request at each step.
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
request_id = str(uuid.uuid4().hex[:8])
server.add_request(request_id, prompt, sampling_params)
server.add_request(str(request_id), prompt, sampling_params)
request_id += 1
request_outputs = server.step()
for request_output in request_outputs:
if request_output.done:
if request_output.finished():
print(request_output)
if not (server.has_unfinished_requests() or test_prompts):