From 017ef648e9eaa8c70a6e32532c861aed4a2ae463 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Thu, 12 Jun 2025 13:30:56 -0400 Subject: [PATCH] [Spec Decode][Benchmark] Generalize spec decode offline benchmark to more methods and datasets (#18847) --- examples/offline_inference/eagle.py | 4 + examples/offline_inference/spec_decode.py | 137 ++++++++++++ tests/benchmarks/test_serve_cli.py | 2 + vllm/benchmarks/datasets.py | 252 ++++++++++++++++++++++ vllm/benchmarks/serve.py | 231 +------------------- 5 files changed, 403 insertions(+), 223 deletions(-) create mode 100644 examples/offline_inference/spec_decode.py diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index ce977ee99b..f4193fdb8b 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -137,4 +137,8 @@ def main(): if __name__ == "__main__": + print( + "[WARNING] Use examples/offline_inference/spec_decode.py" + " instead of this script." + ) main() diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py new file mode 100644 index 0000000000..eece8beced --- /dev/null +++ b/examples/offline_inference/spec_decode.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 + +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.v1.metrics.reader import Counter, Vector + +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + + +def parse_args(): + parser = FlexibleArgumentParser() + add_dataset_parser(parser) + parser.add_argument( + "--dataset", + type=str, + default="./examples/data/gsm8k.jsonl", + help="downloaded from the eagle repo " + "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/", + ) + parser.add_argument( + "--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3"] + ) + parser.add_argument("--max-num-seqs", type=int, default=8) + parser.add_argument("--num-spec-tokens", type=int, default=2) + parser.add_argument("--prompt-lookup-max", type=int, default=5) + parser.add_argument("--prompt-lookup-min", type=int, default=2) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--draft-tp", type=int, default=1) + parser.add_argument("--enforce-eager", action="store_true") + parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--max-num-batched-tokens", type=int, default=2048) + parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=-1) + parser.add_argument("--print-output", action="store_true") + parser.add_argument("--output-len", type=int, default=256) + return parser.parse_args() + + +def main(): + args = parse_args() + args.endpoint_type = "openai-chat" + + model_dir = "meta-llama/Llama-3.1-8B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_dir) + max_model_len = 2048 + + prompts = get_samples(args, tokenizer) + # add_special_tokens is False to avoid adding bos twice when using chat templates + prompt_ids = [ + tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts + ] + + if args.method == "eagle" or args.method == "eagle3": + if args.method == "eagle": + eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + elif args.method == "eagle3": + 
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + speculative_config = { + "method": args.method, + "model": eagle_dir, + "num_speculative_tokens": args.num_spec_tokens, + "draft_tensor_parallel_size": args.draft_tp, + "max_model_len": max_model_len, + } + elif args.method == "ngram": + speculative_config = { + "method": "ngram", + "num_speculative_tokens": args.num_spec_tokens, + "prompt_lookup_max": args.prompt_lookup_max, + "prompt_lookup_min": args.prompt_lookup_min, + "max_model_len": max_model_len, + } + else: + raise ValueError(f"unknown method: {args.method}") + + llm = LLM( + model=model_dir, + trust_remote_code=True, + tensor_parallel_size=args.tp, + enable_chunked_prefill=args.enable_chunked_prefill, + max_num_batched_tokens=args.max_num_batched_tokens, + enforce_eager=args.enforce_eager, + max_model_len=max_model_len, + max_num_seqs=args.max_num_seqs, + gpu_memory_utilization=0.8, + speculative_config=speculative_config, + disable_log_stats=False, + ) + + sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) + outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + + # print the generated text + if args.print_output: + for output in outputs: + print("-" * 50) + print(f"prompt: {output.prompt}") + print(f"generated text: {output.outputs[0].text}") + print("-" * 50) + + try: + metrics = llm.get_metrics() + except AssertionError: + print("Metrics are not supported in the V0 engine.") + return + + num_drafts = num_accepted = 0 + acceptance_counts = [0] * args.num_spec_tokens + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(len(metric.values)): + acceptance_counts[pos] += metric.values[pos] + + print("-" * 50) + print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") + print("-" * 50) + + # print acceptance at each token position + for i in range(len(acceptance_counts)): + print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") + + +if __name__ == "__main__": + main() diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py index a318195267..bfcf274727 100644 --- a/tests/benchmarks/test_serve_cli.py +++ b/tests/benchmarks/test_serve_cli.py @@ -31,6 +31,8 @@ def test_bench_serve(server): server.host, "--port", str(server.port), + "--dataset-name", + "random", "--random-input-len", "32", "--random-output-len", diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 4da9f7368e..3efbe56957 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -50,6 +50,11 @@ try: except ImportError: librosa = PlaceholderModule("librosa") +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- @@ -458,6 +463,253 @@ class ShareGPTDataset(BenchmarkDataset): return samples +def add_dataset_parser(parser: FlexibleArgumentParser): + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + 
parser.add_argument( + "--dataset-name", + type=str, + default="random", + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.", + ) + + # group for dataset specific arguments + custom_group = parser.add_argument_group("custom dataset options") + custom_group.add_argument( + "--custom-output-len", + type=int, + default=256, + help= + "Number of output tokens per request, used only for custom dataset.", + ) + custom_group.add_argument( + "--custom-skip-chat-template", + action="store_true", + help= + "Skip applying chat template to prompt, used only for custom dataset.", + ) + + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.", + ) + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for random sampling. Must be in the range [0, 1) to define " + "a symmetric sampling range" + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help=("Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]."), + ) + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. 
Overrides the output lengths " + "from the sampled HF dataset.", + ) + + +def get_samples(args, tokenizer) -> list[SampleRequest]: + if args.dataset_name == "custom": + dataset = CustomDataset(dataset_path=args.dataset_path) + input_requests = dataset.sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.custom_output_len, + skip_chat_template=args.custom_skip_chat_template, + ) + + elif args.dataset_name == "sonnet": + dataset = SonnetDataset(dataset_path=args.dataset_path) + # For the "sonnet" dataset, formatting depends on the backend. + if args.endpoint_type == "openai-chat": + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + ) + else: + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + ) + + elif args.dataset_name == "hf": + # all following datasets are implemented from the + # HuggingFaceDataset base class + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_class = VisionArenaDataset + args.hf_split = "train" + args.hf_subset = None + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_class = InstructCoderDataset + args.hf_split = "train" + elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MTBenchDataset + args.hf_split = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ConversationDataset + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_class = AIMODataset + args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" + elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ASRDataset + args.hf_split = "train" + else: + supported_datasets = set([ + dataset_name for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ]) + raise ValueError( + f"Unsupported dataset path: {args.dataset_path}. " + "Huggingface dataset only supports dataset_path" + f" from one of following: {supported_datasets}. " + "Please consider contributing if you would " + "like to add support for additional dataset formats.") + + if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [ + "openai-chat", + "openai-audio", + ]: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backend.") + input_requests = dataset_class( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + random_seed=args.seed, + ).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.hf_output_len, + ) + + else: + # For datasets that follow a similar structure, use a mapping. 
+ dataset_mapping = { + "sharegpt": + lambda: ShareGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": + lambda: BurstGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path). + sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": + lambda: RandomDataset(dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + ), + } + + try: + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + + return input_requests + + # ----------------------------------------------------------------------------- # Custom Dataset Implementation # ----------------------------------------------------------------------------- diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 019ebcf8d5..4487d2d684 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -32,12 +32,8 @@ import numpy as np from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase -from vllm.benchmarks.datasets import (AIMODataset, ASRDataset, BurstGPTDataset, - ConversationDataset, HuggingFaceDataset, - InstructCoderDataset, MTBenchDataset, - NextEditPredictionDataset, RandomDataset, - SampleRequest, ShareGPTDataset, - SonnetDataset, VisionArenaDataset) +from vllm.benchmarks.datasets import (SampleRequest, add_dataset_parser, + get_samples) from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS, OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, @@ -543,6 +539,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, def add_cli_args(parser: argparse.ArgumentParser): + add_dataset_parser(parser) parser.add_argument( "--endpoint-type", type=str, @@ -571,20 +568,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default="/v1/completions", help="API endpoint.", ) - parser.add_argument( - "--dataset-name", - type=str, - default="random", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], - help="Name of the dataset to benchmark on.", - ) - parser.add_argument( - "--dataset-path", - type=str, - default=None, - help="Path to the sharegpt/sonnet dataset. " - "Or the huggingface dataset ID if using HF dataset.", - ) parser.add_argument( "--max-concurrency", type=int, @@ -611,12 +594,6 @@ def add_cli_args(parser: argparse.ArgumentParser): "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument( - "--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.", - ) parser.add_argument( "--logprobs", type=int, @@ -648,7 +625,6 @@ def add_cli_args(parser: argparse.ArgumentParser): "bursty requests. 
A higher burstiness value (burstiness > 1) " "results in a more uniform arrival of requests.", ) - parser.add_argument("--seed", type=int, default=0) parser.add_argument( "--trust-remote-code", action="store_true", @@ -739,89 +715,6 @@ def add_cli_args(parser: argparse.ArgumentParser): "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) - # group for dataset specific arguments - sonnet_group = parser.add_argument_group("sonnet dataset options") - sonnet_group.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", - ) - - sharegpt_group = parser.add_argument_group("sharegpt dataset options") - sharegpt_group.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.", - ) - - random_group = parser.add_argument_group("random dataset options") - random_group.add_argument( - "--random-input-len", - type=int, - default=1024, - help= - "Number of input tokens per request, used only for random sampling.", - ) - random_group.add_argument( - "--random-output-len", - type=int, - default=128, - help= - "Number of output tokens per request, used only for random sampling.", - ) - random_group.add_argument( - "--random-range-ratio", - type=float, - default=0.0, - help="Range ratio for sampling input/output length, " - "used only for random sampling. Must be in the range [0, 1) to define " - "a symmetric sampling range" - "[length * (1 - range_ratio), length * (1 + range_ratio)].", - ) - random_group.add_argument( - "--random-prefix-len", - type=int, - default=0, - help="Number of fixed prefix tokens before random " - " context. The length range of context in a random " - " request is [random-prefix-len, " - " random-prefix-len + random-prefix-len * random-range-ratio).") - - hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - hf_group.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") - hf_group.add_argument( - "--hf-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output lengths " - "from the sampled HF dataset.", - ) - sampling_group = parser.add_argument_group("sampling parameters") sampling_group.add_argument( "--top-p", @@ -884,7 +777,6 @@ def main(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) - endpoint_type = args.endpoint_type label = args.label model_id = args.model model_name = args.served_model_name @@ -907,115 +799,8 @@ def main(args: argparse.Namespace): "Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.") - if args.dataset_name == "sonnet": - dataset = SonnetDataset(dataset_path=args.dataset_path) - # For the "sonnet" dataset, formatting depends on the backend. 
- if args.backend == "openai-chat": - input_requests = dataset.sample( - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=False, - ) - else: - assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") - input_requests = dataset.sample( - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=True, - ) - - elif args.dataset_name == "hf": - # all following datasets are implemented from the - # HuggingFaceDataset base class - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: - dataset_class = VisionArenaDataset - args.hf_split = "train" - args.hf_subset = None - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: - dataset_class = InstructCoderDataset - args.hf_split = "train" - elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: - dataset_class = MTBenchDataset - args.hf_split = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: - dataset_class = ConversationDataset - args.hf_split = "train" - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: - dataset_class = AIMODataset - args.hf_split = "train" - elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 - dataset_class = NextEditPredictionDataset - args.hf_split = "train" - elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: - dataset_class = ASRDataset - args.hf_split = "train" - else: - supported_datasets = set([ - dataset_name for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ]) - raise ValueError( - f"Unsupported dataset path: {args.dataset_path}. " - "Huggingface dataset only supports dataset_path" - f" from one of following: {supported_datasets}. " - "Please consider contributing if you would " - "like to add support for additional dataset formats.") - - if dataset_class.IS_MULTIMODAL and endpoint_type not in [ - "openai-chat", - "openai-audio", - ]: - # multi-modal benchmark is only available on OpenAI Chat backend. - raise ValueError( - "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' backend.") - input_requests = dataset_class( - dataset_path=args.dataset_path, - dataset_subset=args.hf_subset, - dataset_split=args.hf_split, - random_seed=args.seed, - ).sample( - num_requests=args.num_prompts, - tokenizer=tokenizer, - output_len=args.hf_output_len, - ) - - else: - # For datasets that follow a similar structure, use a mapping. - dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). 
- sample(tokenizer=tokenizer, num_requests=args.num_prompts), - "random": - lambda: RandomDataset(dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - prefix_len=args.random_prefix_len, - input_len=args.random_input_len, - output_len=args.random_output_len, - range_ratio=args.random_range_ratio, - ), - } - - try: - input_requests = dataset_mapping[args.dataset_name]() - except KeyError as err: - raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + # Load the dataset. + input_requests = get_samples(args, tokenizer) goodput_config_dict = check_goodput_args(args) # Collect the sampling parameters. @@ -1043,7 +828,7 @@ def main(args: argparse.Namespace): benchmark_result = asyncio.run( benchmark( - endpoint_type=endpoint_type, + endpoint_type=args.endpoint_type, api_url=api_url, base_url=base_url, model_id=model_id, @@ -1073,7 +858,7 @@ def main(args: argparse.Namespace): # Setup current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") result_json["date"] = current_dt - result_json["endpoint_type"] = endpoint_type + result_json["endpoint_type"] = args.endpoint_type result_json["label"] = label result_json["model_id"] = model_id result_json["tokenizer_id"] = tokenizer_id @@ -1118,7 +903,7 @@ def main(args: argparse.Namespace): base_model_id = model_id.split("/")[-1] max_concurrency_str = (f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else "") - label = label or endpoint_type + label = label or args.endpoint_type file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa if args.result_filename: file_name = args.result_filename
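
---

For readers skimming the patch, below is a minimal sketch (not part of the change) of how the newly extracted add_dataset_parser / get_samples helpers compose on their own. The import paths, the argument names, and the args.endpoint_type convention are taken from the spec_decode.py and datasets.py hunks above; the model name, the concrete flag values, and the standalone main() wrapper are illustrative assumptions only.

    # Illustrative sketch of the dataset helpers factored out by this patch.
    # Assumptions: model name and flag values are examples, not part of the PR.
    from transformers import AutoTokenizer

    from vllm.benchmarks.datasets import add_dataset_parser, get_samples

    try:
        from vllm.utils import FlexibleArgumentParser
    except ImportError:
        from argparse import ArgumentParser as FlexibleArgumentParser


    def main() -> None:
        parser = FlexibleArgumentParser()
        # Registers --dataset-name, --dataset-path, --num-prompts, --seed and the
        # per-dataset option groups (custom/sonnet/sharegpt/random/hf).
        add_dataset_parser(parser)
        args = parser.parse_args(["--dataset-name", "random", "--num-prompts", "4"])
        # get_samples() consults endpoint_type for chat-style datasets (sonnet/hf);
        # the new spec_decode.py example sets it the same way before sampling.
        args.endpoint_type = "openai-chat"

        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
        requests = get_samples(args, tokenizer)  # returns list[SampleRequest]
        for req in requests:
            print(req.prompt[:60])


    if __name__ == "__main__":
        main()

This is the same pair of helpers that vllm/benchmarks/serve.py now calls in place of its former inline dataset handling, which is why the sonnet/sharegpt/random/hf argument groups and the per-dataset sampling logic move from serve.py into vllm/benchmarks/datasets.py in this patch, and why examples/offline_inference/spec_decode.py can reuse them for offline speculative-decoding benchmarks.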