[Spec Decode][Benchmark] Generalize spec decode offline benchmark to more methods and datasets (#18847)
@@ -137,4 +137,8 @@ def main():


if __name__ == "__main__":
+    print(
+        "[WARNING] Use examples/offline_inference/spec_decode.py"
+        " instead of this script."
+    )
    main()
examples/offline_inference/spec_decode.py (new file, 137 lines)
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.v1.metrics.reader import Counter, Vector

try:
    from vllm.utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)
    parser.add_argument(
        "--dataset",
        type=str,
        default="./examples/data/gsm8k.jsonl",
        help="downloaded from the eagle repo "
        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
    )
    parser.add_argument(
        "--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3"]
    )
    parser.add_argument("--max-num-seqs", type=int, default=8)
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
    parser.add_argument("--prompt-lookup-min", type=int, default=2)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--draft-tp", type=int, default=1)
    parser.add_argument("--enforce-eager", action="store_true")
    parser.add_argument("--enable-chunked-prefill", action="store_true")
    parser.add_argument("--max-num-batched-tokens", type=int, default=2048)
    parser.add_argument("--temp", type=float, default=0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=-1)
    parser.add_argument("--print-output", action="store_true")
    parser.add_argument("--output-len", type=int, default=256)
    return parser.parse_args()


def main():
    args = parse_args()
    args.endpoint_type = "openai-chat"

    model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    max_model_len = 2048

    prompts = get_samples(args, tokenizer)
    # add_special_tokens is False to avoid adding bos twice
    # when using chat templates
    prompt_ids = [
        tokenizer.encode(prompt.prompt, add_special_tokens=False)
        for prompt in prompts
    ]

    if args.method == "eagle" or args.method == "eagle3":
        if args.method == "eagle":
            eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
        elif args.method == "eagle3":
            eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
        speculative_config = {
            "method": args.method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
            "draft_tensor_parallel_size": args.draft_tp,
            "max_model_len": max_model_len,
        }
    elif args.method == "ngram":
        speculative_config = {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
            "max_model_len": max_model_len,
        }
    else:
        raise ValueError(f"unknown method: {args.method}")

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        max_num_batched_tokens=args.max_num_batched_tokens,
        enforce_eager=args.enforce_eager,
        max_model_len=max_model_len,
        max_num_seqs=args.max_num_seqs,
        gpu_memory_utilization=0.8,
        speculative_config=speculative_config,
        disable_log_stats=False,
    )

    sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
    outputs = llm.generate(prompt_token_ids=prompt_ids,
                           sampling_params=sampling_params)

    # print the generated text
    if args.print_output:
        for output in outputs:
            print("-" * 50)
            print(f"prompt: {output.prompt}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    try:
        metrics = llm.get_metrics()
    except AssertionError:
        print("Metrics are not supported in the V0 engine.")
        return

    num_drafts = num_accepted = 0
    acceptance_counts = [0] * args.num_spec_tokens
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos in range(len(metric.values)):
                acceptance_counts[pos] += metric.values[pos]

    print("-" * 50)
    print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i in range(len(acceptance_counts)):
        print(f"acceptance at token {i}: {acceptance_counts[i] / num_drafts:.2f}")


if __name__ == "__main__":
    main()
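A hedged usage sketch (not part of the commit): every flag below is defined by parse_args/add_dataset_parser above; the target model and the EAGLE draft weights are fetched from the Hugging Face Hub on first use.

    python examples/offline_inference/spec_decode.py \
        --method eagle --num-spec-tokens 3 \
        --dataset-name random --num-prompts 32 --print-output

On the metrics side, the mean acceptance length is 1 + num_accepted / num_drafts; for example, 2000 accepted draft tokens over 1000 drafts prints 3.00, i.e. on average two extra tokens are accepted per draft step.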
@@ -31,6 +31,8 @@ def test_bench_serve(server):
        server.host,
        "--port",
        str(server.port),
+        "--dataset-name",
+        "random",
        "--random-input-len",
        "32",
        "--random-output-len",
@@ -50,6 +50,11 @@ try:
except ImportError:
    librosa = PlaceholderModule("librosa")

+try:
+    from vllm.utils import FlexibleArgumentParser
+except ImportError:
+    from argparse import ArgumentParser as FlexibleArgumentParser
+
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
@@ -458,6 +463,253 @@ class ShareGPTDataset(BenchmarkDataset):
        return samples


def add_dataset_parser(parser: FlexibleArgumentParser):
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="random",
        choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default=None,
        help="Path to the sharegpt/sonnet dataset. "
        "Or the huggingface dataset ID if using HF dataset.",
    )

    # group for dataset specific arguments
    custom_group = parser.add_argument_group("custom dataset options")
    custom_group.add_argument(
        "--custom-output-len",
        type=int,
        default=256,
        help="Number of output tokens per request, used only for custom dataset.",
    )
    custom_group.add_argument(
        "--custom-skip-chat-template",
        action="store_true",
        help="Skip applying chat template to prompt, used only for custom dataset.",
    )

    sonnet_group = parser.add_argument_group("sonnet dataset options")
    sonnet_group.add_argument(
        "--sonnet-input-len",
        type=int,
        default=550,
        help="Number of input tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-output-len",
        type=int,
        default=150,
        help="Number of output tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=200,
        help="Number of prefix tokens per request, used only for sonnet dataset.",
    )

    sharegpt_group = parser.add_argument_group("sharegpt dataset options")
    sharegpt_group.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.",
    )

    random_group = parser.add_argument_group("random dataset options")
    random_group.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random sampling.",
    )
    random_group.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for random sampling.",
    )
    random_group.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output length, "
        "used only for random sampling. Must be in the range [0, 1) to define "
        "a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )
    random_group.add_argument(
        "--random-prefix-len",
        type=int,
        default=0,
        help=("Number of fixed prefix tokens before the random context "
              "in a request. "
              "The total input length is the sum of `random-prefix-len` and "
              "a random "
              "context length sampled from [input_len * (1 - range_ratio), "
              "input_len * (1 + range_ratio)]."),
    )

    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument("--hf-subset",
                          type=str,
                          default=None,
                          help="Subset of the HF dataset.")
    hf_group.add_argument("--hf-split",
                          type=str,
                          default=None,
                          help="Split of the HF dataset.")
    hf_group.add_argument(
        "--hf-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output lengths "
        "from the sampled HF dataset.",
    )


def get_samples(args, tokenizer) -> list[SampleRequest]:
    if args.dataset_name == "custom":
        dataset = CustomDataset(dataset_path=args.dataset_path)
        input_requests = dataset.sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            output_len=args.custom_output_len,
            skip_chat_template=args.custom_skip_chat_template,
        )

    elif args.dataset_name == "sonnet":
        dataset = SonnetDataset(dataset_path=args.dataset_path)
        # For the "sonnet" dataset, formatting depends on the backend.
        if args.endpoint_type == "openai-chat":
            input_requests = dataset.sample(
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
                return_prompt_formatted=False,
            )
        else:
            assert tokenizer.chat_template or tokenizer.default_chat_template, (
                "Tokenizer/model must have chat template for sonnet dataset.")
            input_requests = dataset.sample(
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
                return_prompt_formatted=True,
            )

    elif args.dataset_name == "hf":
        # all following datasets are implemented from the
        # HuggingFaceDataset base class
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = VisionArenaDataset
            args.hf_split = "train"
            args.hf_subset = None
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = InstructCoderDataset
            args.hf_split = "train"
        elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = MTBenchDataset
            args.hf_split = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = ConversationDataset
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_class = AIMODataset
            args.hf_split = "train"
        elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS:  # noqa: E501
            dataset_class = NextEditPredictionDataset
            args.hf_split = "train"
        elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = ASRDataset
            args.hf_split = "train"
        else:
            supported_datasets = set([
                dataset_name for cls in HuggingFaceDataset.__subclasses__()
                for dataset_name in cls.SUPPORTED_DATASET_PATHS
            ])
            raise ValueError(
                f"Unsupported dataset path: {args.dataset_path}. "
                "Huggingface dataset only supports dataset_path"
                f" from one of following: {supported_datasets}. "
                "Please consider contributing if you would "
                "like to add support for additional dataset formats.")

        if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [
                "openai-chat",
                "openai-audio",
        ]:
            # multi-modal benchmark is only available on OpenAI Chat backend.
            raise ValueError(
                "Multi-modal content is only supported on 'openai-chat' and "
                "'openai-audio' backend.")
        input_requests = dataset_class(
            dataset_path=args.dataset_path,
            dataset_subset=args.hf_subset,
            dataset_split=args.hf_split,
            random_seed=args.seed,
        ).sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            output_len=args.hf_output_len,
        )

    else:
        # For datasets that follow a similar structure, use a mapping.
        dataset_mapping = {
            "sharegpt":
            lambda: ShareGPTDataset(random_seed=args.seed,
                                    dataset_path=args.dataset_path).sample(
                                        tokenizer=tokenizer,
                                        num_requests=args.num_prompts,
                                        output_len=args.sharegpt_output_len,
                                    ),
            "burstgpt":
            lambda: BurstGPTDataset(random_seed=args.seed,
                                    dataset_path=args.dataset_path).
            sample(tokenizer=tokenizer, num_requests=args.num_prompts),
            "random":
            lambda: RandomDataset(dataset_path=args.dataset_path).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                prefix_len=args.random_prefix_len,
                input_len=args.random_input_len,
                output_len=args.random_output_len,
                range_ratio=args.random_range_ratio,
            ),
        }

        try:
            input_requests = dataset_mapping[args.dataset_name]()
        except KeyError as err:
            raise ValueError(f"Unknown dataset: {args.dataset_name}") from err

    return input_requests


# -----------------------------------------------------------------------------
# Custom Dataset Implementation
# -----------------------------------------------------------------------------
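A minimal sketch (not part of the commit) of how the two new helpers compose outside of spec_decode.py. The tokenizer name is only an example, and endpoint_type is set by hand because get_samples consults it on the sonnet/hf paths:

    from argparse import ArgumentParser

    from transformers import AutoTokenizer

    from vllm.benchmarks.datasets import add_dataset_parser, get_samples

    parser = ArgumentParser()
    add_dataset_parser(parser)
    args = parser.parse_args(["--dataset-name", "random", "--num-prompts", "4"])
    args.endpoint_type = "openai-chat"  # normally set by the calling benchmark
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    # Each SampleRequest carries the prompt plus its token lengths.
    for req in get_samples(args, tokenizer):
        print(req.prompt_len, req.expected_output_len)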
@@ -32,12 +32,8 @@ import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

-from vllm.benchmarks.datasets import (AIMODataset, ASRDataset, BurstGPTDataset,
-                                      ConversationDataset, HuggingFaceDataset,
-                                      InstructCoderDataset, MTBenchDataset,
-                                      NextEditPredictionDataset, RandomDataset,
-                                      SampleRequest, ShareGPTDataset,
-                                      SonnetDataset, VisionArenaDataset)
+from vllm.benchmarks.datasets import (SampleRequest, add_dataset_parser,
+                                      get_samples)
from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS,
                                                   OPENAI_COMPATIBLE_BACKENDS,
                                                   RequestFuncInput,
@@ -543,6 +539,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,


def add_cli_args(parser: argparse.ArgumentParser):
+    add_dataset_parser(parser)
    parser.add_argument(
        "--endpoint-type",
        type=str,
@@ -571,20 +568,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
        default="/v1/completions",
        help="API endpoint.",
    )
-    parser.add_argument(
-        "--dataset-name",
-        type=str,
-        default="random",
-        choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
-        help="Name of the dataset to benchmark on.",
-    )
-    parser.add_argument(
-        "--dataset-path",
-        type=str,
-        default=None,
-        help="Path to the sharegpt/sonnet dataset. "
-        "Or the huggingface dataset ID if using HF dataset.",
-    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
@@ -611,12 +594,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
    )
    parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument(
-        "--num-prompts",
-        type=int,
-        default=1000,
-        help="Number of prompts to process.",
-    )
    parser.add_argument(
        "--logprobs",
        type=int,
@@ -648,7 +625,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
        "bursty requests. A higher burstiness value (burstiness > 1) "
        "results in a more uniform arrival of requests.",
    )
-    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
@@ -739,89 +715,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
        "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
    )

-    # group for dataset specific arguments
-    sonnet_group = parser.add_argument_group("sonnet dataset options")
-    sonnet_group.add_argument(
-        "--sonnet-input-len",
-        type=int,
-        default=550,
-        help="Number of input tokens per request, used only for sonnet dataset.",
-    )
-    sonnet_group.add_argument(
-        "--sonnet-output-len",
-        type=int,
-        default=150,
-        help="Number of output tokens per request, used only for sonnet dataset.",
-    )
-    sonnet_group.add_argument(
-        "--sonnet-prefix-len",
-        type=int,
-        default=200,
-        help="Number of prefix tokens per request, used only for sonnet dataset.",
-    )
-
-    sharegpt_group = parser.add_argument_group("sharegpt dataset options")
-    sharegpt_group.add_argument(
-        "--sharegpt-output-len",
-        type=int,
-        default=None,
-        help="Output length for each request. Overrides the output length "
-        "from the ShareGPT dataset.",
-    )
-
-    random_group = parser.add_argument_group("random dataset options")
-    random_group.add_argument(
-        "--random-input-len",
-        type=int,
-        default=1024,
-        help="Number of input tokens per request, used only for random sampling.",
-    )
-    random_group.add_argument(
-        "--random-output-len",
-        type=int,
-        default=128,
-        help="Number of output tokens per request, used only for random sampling.",
-    )
-    random_group.add_argument(
-        "--random-range-ratio",
-        type=float,
-        default=0.0,
-        help="Range ratio for sampling input/output length, "
-        "used only for random sampling. Must be in the range [0, 1) to define "
-        "a symmetric sampling range "
-        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
-    )
-    random_group.add_argument(
-        "--random-prefix-len",
-        type=int,
-        default=0,
-        help="Number of fixed prefix tokens before random "
-        "context. The length range of context in a random "
-        "request is [random-prefix-len, "
-        "random-prefix-len + random-prefix-len * random-range-ratio).")
-
-    hf_group = parser.add_argument_group("hf dataset options")
-    hf_group.add_argument("--hf-subset",
-                          type=str,
-                          default=None,
-                          help="Subset of the HF dataset.")
-    hf_group.add_argument("--hf-split",
-                          type=str,
-                          default=None,
-                          help="Split of the HF dataset.")
-    hf_group.add_argument(
-        "--hf-output-len",
-        type=int,
-        default=None,
-        help="Output length for each request. Overrides the output lengths "
-        "from the sampled HF dataset.",
-    )

    sampling_group = parser.add_argument_group("sampling parameters")
    sampling_group.add_argument(
        "--top-p",
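A hedged sanity check (not part of the commit; the import path is assumed from this file's location) that add_cli_args now exposes the shared dataset flags, including the "custom" choice the old serve-local --dataset-name did not offer:

    import argparse

    from vllm.benchmarks.serve import add_cli_args  # assumed module path

    parser = argparse.ArgumentParser()
    add_cli_args(parser)
    # parse_known_args sidesteps any required serve-only flags such as --model.
    args, _ = parser.parse_known_args(["--dataset-name", "custom"])
    assert args.dataset_name == "custom"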
@@ -884,7 +777,6 @@ def main(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)

-    endpoint_type = args.endpoint_type
    label = args.label
    model_id = args.model
    model_name = args.served_model_name
@@ -907,115 +799,8 @@ def main(args: argparse.Namespace):
            "Please specify '--dataset-name' and the corresponding "
            "'--dataset-path' if required.")

-    if args.dataset_name == "sonnet":
-        dataset = SonnetDataset(dataset_path=args.dataset_path)
-        # For the "sonnet" dataset, formatting depends on the backend.
-        if args.backend == "openai-chat":
-            input_requests = dataset.sample(
-                num_requests=args.num_prompts,
-                input_len=args.sonnet_input_len,
-                output_len=args.sonnet_output_len,
-                prefix_len=args.sonnet_prefix_len,
-                tokenizer=tokenizer,
-                return_prompt_formatted=False,
-            )
-        else:
-            assert tokenizer.chat_template or tokenizer.default_chat_template, (
-                "Tokenizer/model must have chat template for sonnet dataset.")
-            input_requests = dataset.sample(
-                num_requests=args.num_prompts,
-                input_len=args.sonnet_input_len,
-                output_len=args.sonnet_output_len,
-                prefix_len=args.sonnet_prefix_len,
-                tokenizer=tokenizer,
-                return_prompt_formatted=True,
-            )
-
-    elif args.dataset_name == "hf":
-        # all following datasets are implemented from the
-        # HuggingFaceDataset base class
-        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
-            dataset_class = VisionArenaDataset
-            args.hf_split = "train"
-            args.hf_subset = None
-        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
-            dataset_class = InstructCoderDataset
-            args.hf_split = "train"
-        elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
-            dataset_class = MTBenchDataset
-            args.hf_split = "train"
-        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
-            dataset_class = ConversationDataset
-            args.hf_split = "train"
-        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
-            dataset_class = AIMODataset
-            args.hf_split = "train"
-        elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS:  # noqa: E501
-            dataset_class = NextEditPredictionDataset
-            args.hf_split = "train"
-        elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
-            dataset_class = ASRDataset
-            args.hf_split = "train"
-        else:
-            supported_datasets = set([
-                dataset_name for cls in HuggingFaceDataset.__subclasses__()
-                for dataset_name in cls.SUPPORTED_DATASET_PATHS
-            ])
-            raise ValueError(
-                f"Unsupported dataset path: {args.dataset_path}. "
-                "Huggingface dataset only supports dataset_path"
-                f" from one of following: {supported_datasets}. "
-                "Please consider contributing if you would "
-                "like to add support for additional dataset formats.")
-
-        if dataset_class.IS_MULTIMODAL and endpoint_type not in [
-                "openai-chat",
-                "openai-audio",
-        ]:
-            # multi-modal benchmark is only available on OpenAI Chat backend.
-            raise ValueError(
-                "Multi-modal content is only supported on 'openai-chat' and "
-                "'openai-audio' backend.")
-        input_requests = dataset_class(
-            dataset_path=args.dataset_path,
-            dataset_subset=args.hf_subset,
-            dataset_split=args.hf_split,
-            random_seed=args.seed,
-        ).sample(
-            num_requests=args.num_prompts,
-            tokenizer=tokenizer,
-            output_len=args.hf_output_len,
-        )
-
-    else:
-        # For datasets that follow a similar structure, use a mapping.
-        dataset_mapping = {
-            "sharegpt":
-            lambda: ShareGPTDataset(random_seed=args.seed,
-                                    dataset_path=args.dataset_path).sample(
-                                        tokenizer=tokenizer,
-                                        num_requests=args.num_prompts,
-                                        output_len=args.sharegpt_output_len,
-                                    ),
-            "burstgpt":
-            lambda: BurstGPTDataset(random_seed=args.seed,
-                                    dataset_path=args.dataset_path).
-            sample(tokenizer=tokenizer, num_requests=args.num_prompts),
-            "random":
-            lambda: RandomDataset(dataset_path=args.dataset_path).sample(
-                tokenizer=tokenizer,
-                num_requests=args.num_prompts,
-                prefix_len=args.random_prefix_len,
-                input_len=args.random_input_len,
-                output_len=args.random_output_len,
-                range_ratio=args.random_range_ratio,
-            ),
-        }
-
-        try:
-            input_requests = dataset_mapping[args.dataset_name]()
-        except KeyError as err:
-            raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
+    # Load the dataset.
+    input_requests = get_samples(args, tokenizer)
    goodput_config_dict = check_goodput_args(args)

    # Collect the sampling parameters.
@@ -1043,7 +828,7 @@ def main(args: argparse.Namespace):

    benchmark_result = asyncio.run(
        benchmark(
-            endpoint_type=endpoint_type,
+            endpoint_type=args.endpoint_type,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
@@ -1073,7 +858,7 @@ def main(args: argparse.Namespace):
        # Setup
        current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
        result_json["date"] = current_dt
-        result_json["endpoint_type"] = endpoint_type
+        result_json["endpoint_type"] = args.endpoint_type
        result_json["label"] = label
        result_json["model_id"] = model_id
        result_json["tokenizer_id"] = tokenizer_id
@@ -1118,7 +903,7 @@ def main(args: argparse.Namespace):
        base_model_id = model_id.split("/")[-1]
        max_concurrency_str = (f"-concurrency{args.max_concurrency}"
                               if args.max_concurrency is not None else "")
-        label = label or endpoint_type
+        label = label or args.endpoint_type
        file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
        if args.result_filename:
            file_name = args.result_filename