Add throughput benchmarking script (#133)

Woosuk Kwon
2023-05-28 03:20:05 -07:00
committed by GitHub
parent 337871c6fd
commit 211318d44a
12 changed files with 145 additions and 257 deletions

View File

@@ -1,5 +1,5 @@
 from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput
+from cacheflow.outputs import RequestOutput, CompletionOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
@@ -9,6 +9,7 @@ __all__ = [
     "LLM",
     "SamplingParams",
     "RequestOutput",
+    "CompletionOutput",
     "LLMServer",
     "ServerArgs",
     "initialize_cluster",

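With CompletionOutput exported at the package root, callers can import it alongside LLM and read generated completions without reaching into submodules. A minimal sketch; the model path is a placeholder and the CompletionOutput fields (.text) are assumed from cacheflow.outputs at this commit:

```python
from cacheflow import LLM, SamplingParams

# Placeholder checkpoint for illustration; any HF causal LM path would do.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.8))

for request_output in outputs:
    # Each RequestOutput carries one CompletionOutput per generated sequence.
    for completion in request_output.outputs:
        print(completion.text)
```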
View File

@@ -87,6 +87,9 @@ class Scheduler:
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped

+    def get_num_unfinished_seq_groups(self) -> int:
+        return len(self.waiting) + len(self.running) + len(self.swapped)
+
     def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}

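get_num_unfinished_seq_groups() is a numeric counterpart to has_unfinished_seqs(), which the LLM entrypoint uses below to size its progress bar. A rough sketch of how a driver loop could use the pair; `scheduler` and `step` are hypothetical stand-ins for whatever owns the Scheduler and runs one iteration:

```python
def wait_for_completion(scheduler, step) -> None:
    # Hypothetical helper: `scheduler` is a cacheflow Scheduler and `step`
    # runs one scheduling + model execution iteration (names assumed here).
    total = scheduler.get_num_unfinished_seq_groups()
    while scheduler.has_unfinished_seqs():
        step()
        done = total - scheduler.get_num_unfinished_seq_groups()
        print(f"finished {done}/{total} sequence groups")
```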
View File

@@ -1,5 +1,6 @@
-from typing import List, Optional
+from typing import List, Optional, Union

+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm

 from cacheflow.outputs import RequestOutput
@@ -31,6 +32,11 @@ class LLM:
         self.llm_server = LLMServer.from_server_args(server_args)
         self.request_counter = Counter()

+    def get_tokenizer(
+        self,
+    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+        return self.llm_server.tokenizer
+
     def generate(
         self,
         prompts: List[str],
@@ -41,10 +47,6 @@
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
-
-        # Initialize tqdm.
-        if use_tqdm:
-            pbar = tqdm(total=len(prompts), desc="Processed prompts")
         # Add requests to the server.
         for i in range(len(prompts)):
             prompt = prompts[i]
@@ -52,10 +54,24 @@
                 token_ids = None
             else:
                 token_ids = prompt_token_ids[i]
-            request_id = str(next(self.request_counter))
-            self.llm_server.add_request(request_id, prompt, sampling_params,
-                                        token_ids)
+            self._add_request(prompt, sampling_params, token_ids)
+        return self._run_server(use_tqdm)

+    def _add_request(
+        self,
+        prompt: str,
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]],
+    ) -> None:
+        request_id = str(next(self.request_counter))
+        self.llm_server.add_request(request_id, prompt, sampling_params,
+                                    prompt_token_ids)
+
+    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+        # Initialize tqdm.
+        if use_tqdm:
+            num_requests = self.llm_server.get_num_unfinished_requests()
+            pbar = tqdm(total=num_requests, desc="Processed prompts")
         # Run the server.
         outputs: List[RequestOutput] = []
         while self.llm_server.has_unfinished_requests():

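Splitting generate() into _add_request() and _run_server() queues every prompt before the server starts draining, and sizes the tqdm bar from the server's unfinished-request count rather than len(prompts). A usage sketch under assumed names (the model path and SamplingParams fields are placeholders):

```python
from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder checkpoint for illustration
prompts = [f"Prompt {i}: tell me a fact." for i in range(8)]

# generate() now queues every prompt first, then drains the server;
# the tqdm bar total comes from the server's unfinished-request count.
outputs = llm.generate(prompts, SamplingParams(temperature=0.8), use_tqdm=True)

# get_tokenizer() exposes the underlying HF tokenizer for pre/post-processing.
tokenizer = llm.get_tokenizer()
for prompt, output in zip(prompts, outputs):
    num_prompt_tokens = len(tokenizer(prompt).input_ids)
    print(num_prompt_tokens, output.outputs[0].text)
```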
View File

@@ -151,6 +151,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

+    def get_num_unfinished_requests(self) -> int:
+        return self.scheduler.get_num_unfinished_seq_groups()
+
     def has_unfinished_requests(self) -> bool:
         return self.scheduler.has_unfinished_seqs()
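LLMServer.get_num_unfinished_requests() simply forwards to the scheduler's new counter, giving front ends (and the throughput benchmark this commit adds) a way to size progress reporting. The benchmark script itself is not part of this excerpt, so the following is only an illustrative timing loop built on the public LLM API, with a placeholder model path:

```python
import time

from cacheflow import LLM, SamplingParams


def measure_throughput(prompts, model="facebook/opt-125m"):
    # Illustrative benchmark loop, not the actual script from this commit;
    # the model path and sampling settings are assumptions.
    llm = LLM(model=model)
    sampling_params = SamplingParams(temperature=0.8)

    start = time.time()
    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
    elapsed = time.time() - start

    return len(outputs) / elapsed  # requests per second


if __name__ == "__main__":
    prompts = ["Hello, my name is"] * 64
    print(f"{measure_throughput(prompts):.2f} req/s")
```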