Add script for benchmarking serving throughput (#145)

Woosuk Kwon
2023-06-14 19:55:38 -07:00
committed by GitHub
parent da5ddcd544
commit 311490a720
10 changed files with 421 additions and 415 deletions


@@ -1,3 +1,4 @@
"""Benchmark offline inference throughput."""
import argparse
import json
import random
@@ -5,14 +6,29 @@ import time
from typing import List, Tuple
from cacheflow import LLM, SamplingParams
from transformers import PreTrainedTokenizerBase
import torch
from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
                          PreTrainedTokenizerBase)
from tqdm import tqdm

def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
    config = AutoConfig.from_pretrained(model_name)
    if config.model_type == "llama":
        # A workaround for potential protobuf errors.
        model_name = "hf-internal-testing/llama-tokenizer"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer
    return AutoTokenizer.from_pretrained(model_name)

def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[List[int], int]]:
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
@@ -35,45 +51,52 @@ def sample_requests(
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompt_token_ids[i], output_len))
    # Filter out if the prompt length + output length is greater than 2048.
    tokenized_dataset = [
        (prompt_token_ids, output_len)
        for prompt_token_ids, output_len in tokenized_dataset
        if len(prompt_token_ids) + output_len <= 2048
    ]
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))
    # Sample the requests.
    sampled_requests = random.sample(tokenized_dataset, num_requests)
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests
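# Note (illustrative, not part of this commit): each sampled entry is now a
# (prompt_text, prompt_len, output_len) tuple, e.g. the hypothetical
# ("Summarize the plot of Hamlet.", 9, 128), rather than the previous
# (prompt_token_ids, output_len) pair, so both backends can reuse the raw text.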

def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
def run_cacheflow(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
) -> float:
    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
    )
    tokenizer = llm.get_tokenizer()
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
    # Add the requests to the server.
    for prompt_token_ids, output_len in requests:
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=args.n,
            temperature=0.0 if args.use_beam_search else 1.0,
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=args.use_beam_search,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=None,
            prompt_token_ids=prompt_token_ids,
            prompt=prompt,
            sampling_params=sampling_params,
        )
@@ -81,17 +104,95 @@ def main(args: argparse.Namespace):
    # FIXME(woosuk): Do not use internal method.
    llm._run_server(use_tqdm=True)
    end = time.time()
    return end - start
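# Note (editorial summary, not part of this commit): run_cacheflow queues every
# request up front and then times a single _run_server() call, so the returned
# value is the wall-clock time to drain the whole workload. Since each request
# sets ignore_eos=True and max_tokens=output_len, every prompt generates exactly
# output_len tokens, which makes the tokens/s figure deterministic.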

def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
) -> float:
    assert not use_beam_search
    tokenizer = get_tokenizer(model)
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16)
    llm = llm.cuda()
    pbar = tqdm(total=len(requests))
    start = time.time()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) + max(
                    max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue
        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))
        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.time()
    return end - start
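# Note (editorial summary, not part of this commit): the HF path uses greedy
# static batching: prompts accumulate until max_batch_size is reached or until
# the padded prompt length plus the largest max_new_tokens in the batch would
# exceed the 2048-token window, and only then is the whole batch generated and
# decoded.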

def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = get_tokenizer(args.model)
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
    if args.backend == "cacheflow":
        elapsed_time = run_cacheflow(
            requests, args.model, args.tensor_parallel_size, args.seed, args.n,
            args.use_beam_search)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(
        len(prompt_token_ids) + output_len
        for prompt_token_ids, output_len in requests
        prompt_len + output_len
        for _, prompt_len, output_len in requests
    )
    elapsed_time = end - start
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend", type=str, choices=["cacheflow", "hf"],
                        default="cacheflow")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
@@ -102,5 +203,14 @@ if __name__ == "__main__":
parser.add_argument("--num-prompts", type=int, default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None,
help="Maximum batch size for HF backend.")
args = parser.parse_args()
if args.backend == "cacheflow":
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
elif args.backend == "hf":
if args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.")
main(args)
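
# Example invocations (illustrative only; the script name and dataset path are
# assumptions, not part of this commit):
#   python benchmark_throughput.py --backend cacheflow \
#       --dataset sharegpt.json --model facebook/opt-125m --num-prompts 1000
#   python benchmark_throughput.py --backend hf \
#       --dataset sharegpt.json --model facebook/opt-125m --hf-max-batch-size 32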