Add throughput benchmarking script (#133)
@@ -1,5 +1,5 @@
 from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput
+from cacheflow.outputs import RequestOutput, CompletionOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
@@ -9,6 +9,7 @@ __all__ = [
     "LLM",
     "SamplingParams",
     "RequestOutput",
+    "CompletionOutput",
     "LLMServer",
     "ServerArgs",
     "initialize_cluster",
@@ -87,6 +87,9 @@ class Scheduler:
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped

+    def get_num_unfinished_seq_groups(self) -> int:
+        return len(self.waiting) + len(self.running) + len(self.swapped)
+
     def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
@@ -1,5 +1,6 @@
-from typing import List, Optional
+from typing import List, Optional, Union

+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm

 from cacheflow.outputs import RequestOutput
@@ -31,6 +32,11 @@ class LLM:
         self.llm_server = LLMServer.from_server_args(server_args)
         self.request_counter = Counter()

+    def get_tokenizer(
+        self,
+    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+        return self.llm_server.tokenizer
+
     def generate(
         self,
         prompts: List[str],
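
The new get_tokenizer() accessor exposes the server's Hugging Face tokenizer (a PreTrainedTokenizer or PreTrainedTokenizerFast). A minimal sketch of how a benchmark could use it to count prompt tokens; the model name, prompts, and constructor argument below are placeholders for illustration, not code from this commit:

    from cacheflow import LLM

    llm = LLM(model="facebook/opt-125m")  # placeholder model name
    tokenizer = llm.get_tokenizer()
    prompts = ["Hello, my name is", "The capital of France is"]
    # Hugging Face tokenizers return a BatchEncoding with input_ids.
    num_prompt_tokens = sum(len(tokenizer(p).input_ids) for p in prompts)
    print(f"Total prompt tokens: {num_prompt_tokens}")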
@@ -41,10 +47,6 @@ class LLM:
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
-        # Initialize tqdm.
-        if use_tqdm:
-            pbar = tqdm(total=len(prompts), desc="Processed prompts")
-
         # Add requests to the server.
         for i in range(len(prompts)):
             prompt = prompts[i]
@@ -52,10 +54,24 @@ class LLM:
                 token_ids = None
             else:
                 token_ids = prompt_token_ids[i]
-            request_id = str(next(self.request_counter))
-            self.llm_server.add_request(request_id, prompt, sampling_params,
-                                        token_ids)
+            self._add_request(prompt, sampling_params, token_ids)
+        return self._run_server(use_tqdm)
+
+    def _add_request(
+        self,
+        prompt: str,
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]],
+    ) -> None:
+        request_id = str(next(self.request_counter))
+        self.llm_server.add_request(request_id, prompt, sampling_params,
+                                    prompt_token_ids)

+    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+        # Initialize tqdm.
+        if use_tqdm:
+            num_requests = self.llm_server.get_num_unfinished_requests()
+            pbar = tqdm(total=num_requests, desc="Processed prompts")
         # Run the server.
         outputs: List[RequestOutput] = []
         while self.llm_server.has_unfinished_requests():
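
The hunk above is cut off inside _run_server, so the body of the drain loop is not shown. For orientation only, here is a minimal sketch of how a loop over the visible LLMServer interface could collect finished outputs; the step() call and the output.finished() check are assumptions, not the commit's verbatim code:

    from typing import List
    from tqdm import tqdm
    from cacheflow.outputs import RequestOutput

    def drain_server(llm_server, use_tqdm: bool = True) -> List[RequestOutput]:
        # Illustrative only: step the server until every queued request finishes.
        if use_tqdm:
            pbar = tqdm(total=llm_server.get_num_unfinished_requests(),
                        desc="Processed prompts")
        outputs: List[RequestOutput] = []
        while llm_server.has_unfinished_requests():
            for output in llm_server.step():  # assumed to return updated RequestOutputs
                if output.finished():         # assumed completion check
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        return outputs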
@@ -151,6 +151,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

+    def get_num_unfinished_requests(self) -> int:
+        return self.scheduler.get_num_unfinished_seq_groups()
+
     def has_unfinished_requests(self) -> bool:
         return self.scheduler.has_unfinished_seqs()
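
The throughput benchmarking script named in the commit title is not among the hunks shown here. As a rough illustration of how the refactored offline API could be driven for such a measurement, here is a minimal sketch; the model name, prompt set, sampling settings, and token-counting logic are assumptions for illustration, not code from this commit:

    import time
    from cacheflow import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")   # placeholder model
    prompts = ["Hello, my name is"] * 256  # synthetic workload
    sampling_params = SamplingParams()     # default sampling; a real benchmark would pin the output length

    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    elapsed = time.time() - start

    # Count prompt + generated tokens (assumes CompletionOutput exposes token_ids).
    tokenizer = llm.get_tokenizer()
    num_tokens = sum(
        len(tokenizer(o.prompt).input_ids) + len(o.outputs[0].token_ids)
        for o in outputs
    )
    print(f"Throughput: {len(prompts) / elapsed:.2f} req/s, "
          f"{num_tokens / elapsed:.2f} tokens/s")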