Convert benchmarks to ruff format (#18068)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-05-13 14:43:29 +01:00
committed by GitHub
parent b922c2ebd2
commit 009d9e7590
41 changed files with 3980 additions and 2938 deletions

View File

@ -1,13 +1,9 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - isort profile is set to black
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.isort]
profile = "black"
[tool.ruff]
line-length = 88
exclude = [

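For orientation, a minimal sketch of what such a local override file could look like is shown below. The `exclude` entries (truncated in the hunk above) and the lint table are illustrative placeholders, not the literal contents of the commit.

```toml
# Hypothetical sketch of the local benchmarks pyproject.toml during the
# yapf -> ruff migration, based on the comments above; not the actual file.
[tool.isort]
profile = "black"   # align isort with black/ruff-format conventions

[tool.ruff]
line-length = 88    # local override of the repo-wide line length
exclude = [
    # the real entries are cut off in the hunk above
]

[tool.ruff.lint]
# unlike the root pyproject.toml, the deprecated typing ignores
# (UP006, UP035) are not carried over here
ignore = []         # illustrative
```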
View File

@ -17,7 +17,7 @@ repos:
- id: ruff
args: [--output-format, github, --fix]
- id: ruff-format
files: ^(.buildkite).*
files: ^(.buildkite|benchmarks)/.*
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
@ -28,8 +28,6 @@ repos:
rev: 6.0.1
hooks:
- id: isort
# necessary during the transition from yapf to ruff format
args: [--resolve-all-configs, --config-root, .]
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v20.1.3
hooks:
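The first hunk above widens the `ruff-format` scope to cover `benchmarks/`. Assuming the hook comes from the standard ruff pre-commit mirror, the resulting entry would read roughly as follows; the repo URL and rev are assumptions (they are not shown in the hunk), and the other hooks from the real config are omitted.

```yaml
# Minimal sketch of the ruff hooks after this change; illustrative only.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.7  # illustrative rev, not taken from the commit
    hooks:
      - id: ruff
        args: [--output-format, github, --fix]
      - id: ruff-format
        # benchmarks/ is now formatted by ruff alongside .buildkite/
        files: ^(.buildkite|benchmarks)/.*
```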

View File

@ -12,8 +12,7 @@ from typing import Optional, Union
import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
# NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed.
@ -43,8 +42,7 @@ class RequestFuncOutput:
latency: float = 0.0
output_tokens: int = 0
ttft: float = 0.0 # Time to first token
itl: list[float] = field(
default_factory=list) # list of inter-token latencies
itl: list[float] = field(default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0
error: str = ""
@ -57,8 +55,9 @@ async def async_request_tgi(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(
trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
params = {
"max_new_tokens": request_func_input.output_len,
"do_sample": True,
@ -105,8 +104,7 @@ async def async_request_tgi(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
@ -133,8 +131,9 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(
trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = {
"accumulate_tokens": True,
"text_input": request_func_input.prompt,
@ -159,8 +158,7 @@ async def async_request_trt_llm(
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data:")
chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
data = json.loads(chunk)
output.generated_text += data["text_output"]
@ -172,8 +170,7 @@ async def async_request_trt_llm(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
@ -197,9 +194,9 @@ async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(
trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
@ -217,19 +214,21 @@ async def async_request_deepspeed_mii(
st = time.perf_counter()
try:
async with session.post(url=request_func_input.api_url,
json=payload) as response:
async with session.post(
url=request_func_input.api_url, json=payload
) as response:
if response.status == 200:
parsed_resp = await response.json()
output.latency = time.perf_counter() - st
if "choices" in parsed_resp:
output.generated_text = parsed_resp["choices"][0][
"text"]
output.generated_text = parsed_resp["choices"][0]["text"]
elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0]
else:
output.error = ("Unexpected response format: "
"neither 'choices' nor 'text' found")
output.error = (
"Unexpected response format: "
"neither 'choices' nor 'text' found"
)
output.success = False
output.success = True
else:
@ -250,15 +249,17 @@ async def async_request_openai_completions(
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
assert api_url.endswith(("completions", "profile")), (
"OpenAI Completions API URL must end with 'completions' or 'profile'."
)
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(
trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"repetition_penalty": 1.0,
@ -273,9 +274,7 @@ async def async_request_openai_completions(
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@ -284,8 +283,9 @@ async def async_request_openai_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
first_chunk_received = False
async for chunk_bytes in response.content:
@ -293,8 +293,7 @@ async def async_request_openai_completions(
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
data = json.loads(chunk)
@ -314,21 +313,20 @@ async def async_request_openai_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
"This response will be marked as failed!"
)
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
@ -349,23 +347,22 @@ async def async_request_openai_chat_completions(
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
("chat/completions", "profile")
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
assert api_url.endswith(("chat/completions", "profile")), (
"OpenAI Chat Completions API URL must end with 'chat/completions'."
)
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(
trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"messages": [
{
"role": "user",
"content": content
},
{"role": "user", "content": content},
],
"temperature": 0.0,
"max_completion_tokens": request_func_input.output_len,
@ -391,16 +388,16 @@ async def async_request_openai_chat_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
@ -414,13 +411,11 @@ async def async_request_openai_chat_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
output.output_tokens = usage.get("completion_tokens")
most_recent_timestamp = timestamp
@ -446,25 +441,28 @@ async def async_request_openai_audio(
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
assert api_url.endswith(
("transcriptions", "translations"
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
assert api_url.endswith(("transcriptions", "translations")), (
"OpenAI Chat Completions API URL must end with 'transcriptions' "
)
"or `translations`."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(
trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"temperature": 0.0,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"language": "en",
# Flattened due to multipart/form-data
"stream_include_usage": True,
"stream_continuous_usage_stats": True
"stream_continuous_usage_stats": True,
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
@ -479,9 +477,9 @@ async def async_request_openai_audio(
buffer.seek(0)
return buffer
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData()
form.add_field('file', f, content_type='audio/wav')
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
@ -493,24 +491,22 @@ async def async_request_openai_audio(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url,
data=form,
headers=headers) as response:
async with session.post(
url=api_url, data=form, headers=headers
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get(
"content")
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
@ -519,12 +515,14 @@ async def async_request_openai_audio(
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp)
timestamp - most_recent_timestamp
)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
"completion_tokens"
)
most_recent_timestamp = timestamp
@ -545,7 +543,7 @@ async def async_request_openai_audio(
def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true":
from modelscope import snapshot_download
from vllm.model_executor.model_loader.weight_utils import get_lock
@ -556,7 +554,8 @@ def get_model(pretrained_model_name_or_path: str) -> str:
model_path = snapshot_download(
model_id=pretrained_model_name_or_path,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
)
return model_path
return pretrained_model_name_or_path
@ -569,23 +568,23 @@ def get_tokenizer(
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path):
pretrained_model_name_or_path = get_model(
pretrained_model_name_or_path)
pretrained_model_name_or_path
):
pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError(
"Cannot use the fast tokenizer in slow tokenizer mode.")
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if tokenizer_mode == "mistral":
try:
from vllm.transformers_utils.tokenizer import MistralTokenizer
except ImportError as e:
raise ImportError("MistralTokenizer requires vllm package.\n"
raise ImportError(
"MistralTokenizer requires vllm package.\n"
"Please install it with `pip install vllm` "
"to use mistral tokenizer mode.") from e
return MistralTokenizer.from_pretrained(
str(pretrained_model_name_or_path))
"to use mistral tokenizer mode."
) from e
return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
else:
return AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
@ -608,7 +607,7 @@ ASYNC_REQUEST_FUNCS = {
}
OPENAI_COMPATIBLE_BACKENDS = [
k for k, v in ASYNC_REQUEST_FUNCS.items()
if v in (async_request_openai_completions,
async_request_openai_chat_completions)
k
for k, v in ASYNC_REQUEST_FUNCS.items()
if v in (async_request_openai_completions, async_request_openai_chat_completions)
]

View File

@ -82,14 +82,12 @@ class BenchmarkDataset(ABC):
self.dataset_path = dataset_path
# Set the random seed, ensuring that a None value is replaced with the
# default seed.
self.random_seed = (random_seed
if random_seed is not None else self.DEFAULT_SEED)
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
self.data = None
def apply_multimodal_chat_transformation(
self,
prompt: str,
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
) -> list[dict]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
@ -111,8 +109,7 @@ class BenchmarkDataset(ABC):
NotImplementedError: If a subclass does not implement this method.
"""
# TODO (jenniferzhao): add support for downloading data
raise NotImplementedError(
"load_data must be implemented in subclasses.")
raise NotImplementedError("load_data must be implemented in subclasses.")
def get_random_lora_request(
self,
@ -158,8 +155,9 @@ class BenchmarkDataset(ABC):
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
@abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase,
num_requests: int) -> list[SampleRequest]:
def sample(
self, tokenizer: PreTrainedTokenizerBase, num_requests: int
) -> list[SampleRequest]:
"""
Abstract method to generate sample requests from the dataset.
@ -177,8 +175,9 @@ class BenchmarkDataset(ABC):
"""
raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest],
num_requests: int) -> None:
def maybe_oversample_requests(
self, requests: list[SampleRequest], num_requests: int
) -> None:
"""
Oversamples the list of requests if its size is less than the desired
number.
@ -189,11 +188,9 @@ class BenchmarkDataset(ABC):
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
additional = random.choices(requests,
k=num_requests - len(requests))
additional = random.choices(requests, k=num_requests - len(requests))
requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.",
num_requests)
logger.info("Oversampled requests to reach %d total samples.", num_requests)
# -----------------------------------------------------------------------------
@ -218,14 +215,14 @@ def is_valid_sequence(
"""
# Check for invalid conditions
prompt_too_short = prompt_len < min_len
output_too_short = (not skip_min_output_len_check) and (output_len
< min_len)
output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
prompt_too_long = prompt_len > max_prompt_len
combined_too_long = (prompt_len + output_len) > max_total_len
# Return True if none of the invalid conditions are met
return not (prompt_too_short or output_too_short or prompt_too_long
or combined_too_long)
return not (
prompt_too_short or output_too_short or prompt_too_long or combined_too_long
)
@cache
@ -257,28 +254,28 @@ def process_image(image: Any) -> Mapping[str, Any]:
Raises:
ValueError: If the input is not a supported type.
"""
if isinstance(image, dict) and 'bytes' in image:
image = Image.open(BytesIO(image['bytes']))
if isinstance(image, dict) and "bytes" in image:
image = Image.open(BytesIO(image["bytes"]))
if isinstance(image, Image.Image):
image = image.convert("RGB")
with io.BytesIO() as image_data:
image.save(image_data, format="JPEG")
image_base64 = base64.b64encode(
image_data.getvalue()).decode("utf-8")
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
return {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
}
if isinstance(image, str):
image_url = (image if image.startswith(
("http://", "file://")) else f"file://{image}")
image_url = (
image if image.startswith(("http://", "file://")) else f"file://{image}"
)
return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes.")
raise ValueError(
f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes."
)
# -----------------------------------------------------------------------------
@ -318,8 +315,11 @@ class RandomDataset(BenchmarkDataset):
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
prefix_token_ids = (
np.random.randint(0, vocab_size, size=prefix_len).tolist()
if prefix_len > 0
else []
)
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(real_input_len * (1 - range_ratio))
@ -329,21 +329,17 @@ class RandomDataset(BenchmarkDataset):
# Add logging for debugging
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
logger.info("Sampling output_len from [%s, %s]", output_low,
output_high)
logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
input_lens = np.random.randint(input_low,
input_high + 1,
size=num_requests)
output_lens = np.random.randint(output_low,
output_high + 1,
size=num_requests)
input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests)
requests = []
for i in range(num_requests):
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
vocab_size).tolist()
inner_seq = (
(offsets[i] + i + np.arange(input_lens[i])) % vocab_size
).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
@ -354,8 +350,9 @@ class RandomDataset(BenchmarkDataset):
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:input_lens[i]]
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
: input_lens[i]
]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i])
requests.append(
@ -363,7 +360,8 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
))
)
)
return requests
@ -390,7 +388,8 @@ class ShareGPTDataset(BenchmarkDataset):
self.data = json.load(f)
# Filter entries with at least two conversation turns.
self.data = [
entry for entry in self.data
entry
for entry in self.data
if "conversations" in entry and len(entry["conversations"]) >= 2
]
random.seed(self.random_seed)
@ -416,27 +415,28 @@ class ShareGPTDataset(BenchmarkDataset):
)
lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
new_output_len = (len(completion_ids)
if output_len is None else output_len)
if not is_valid_sequence(prompt_len,
new_output_len = len(completion_ids) if output_len is None else output_len
if not is_valid_sequence(
prompt_len,
new_output_len,
skip_min_output_len_check=output_len
is not None):
skip_min_output_len_check=output_len is not None,
):
continue
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(
prompt, None)
prompt = self.apply_multimodal_chat_transformation(prompt, None)
samples.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=new_output_len,
lora_request=lora_request,
))
)
)
self.maybe_oversample_requests(samples, num_requests)
return samples
@ -482,20 +482,20 @@ class SonnetDataset(BenchmarkDataset):
) -> list:
# Calculate average token length for a poem line.
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
avg_len = sum(len(tokens)
for tokens in tokenized_lines) / len(tokenized_lines)
avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)
# Build the base prompt.
base_prompt = "Pick as many lines as you can from these poem lines:\n"
base_msg = [{"role": "user", "content": base_prompt}]
base_fmt = tokenizer.apply_chat_template(base_msg,
add_generation_prompt=True,
tokenize=False)
base_fmt = tokenizer.apply_chat_template(
base_msg, add_generation_prompt=True, tokenize=False
)
base_offset = len(tokenizer(base_fmt).input_ids)
if input_len <= base_offset:
raise ValueError(
f"'input_len' must be higher than the base prompt length "
f"({base_offset}).")
f"({base_offset})."
)
# Determine how many poem lines to use.
num_input_lines = round((input_len - base_offset) / avg_len)
@ -504,21 +504,23 @@ class SonnetDataset(BenchmarkDataset):
samples = []
while len(samples) < num_requests:
extra_lines = random.choices(self.data,
k=num_input_lines - num_prefix_lines)
extra_lines = random.choices(
self.data, k=num_input_lines - num_prefix_lines
)
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
msg = [{"role": "user", "content": prompt}]
prompt_formatted = tokenizer.apply_chat_template(
msg, add_generation_prompt=True, tokenize=False)
msg, add_generation_prompt=True, tokenize=False
)
prompt_len = len(tokenizer(prompt_formatted).input_ids)
if prompt_len <= input_len:
samples.append(
SampleRequest(
prompt=prompt_formatted
if return_prompt_formatted else prompt,
prompt=prompt_formatted if return_prompt_formatted else prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
)
)
return samples
@ -538,7 +540,9 @@ class BurstGPTDataset(BenchmarkDataset):
super().__init__(**kwargs)
self.load_data()
def load_data(self, ):
def load_data(
self,
):
if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.")
@ -552,8 +556,7 @@ class BurstGPTDataset(BenchmarkDataset):
def _sample_loaded_data(self, num_requests: int) -> list:
if num_requests <= len(self.data):
data = self.data.sample(n=num_requests,
random_state=self.random_seed)
data = self.data.sample(n=num_requests, random_state=self.random_seed)
else:
data = self.data.sample(
n=num_requests,
@ -577,7 +580,8 @@ class BurstGPTDataset(BenchmarkDataset):
input_len = int(data[i][2])
output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
)
vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size.
@ -589,7 +593,8 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len,
expected_output_len=output_len,
lora_request=lora_req,
))
)
)
return samples
@ -632,20 +637,23 @@ class HuggingFaceDataset(BenchmarkDataset):
class ConversationDataset(HuggingFaceDataset):
"""Dataset for conversation data with multimodal support."""
SUPPORTED_DATASET_PATHS = {
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
"lmms-lab/LLaVA-OneVision-Data",
"Aeala/ShareGPT_Vicuna_unfiltered",
}
IS_MULTIMODAL = True
def sample(self,
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
**kwargs,
) -> list:
# Filter examples with at least 2 conversations
filtered_data = self.data.filter(
lambda x: len(x["conversations"]) >= 2)
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
dynamic_output = output_len is None
@ -661,24 +669,22 @@ class ConversationDataset(HuggingFaceDataset):
completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0
if dynamic_output and not is_valid_sequence(
prompt_len, completion_len):
if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
continue
mm_content = process_image(
item["image"]) if "image" in item else None
mm_content = process_image(item["image"]) if "image" in item else None
if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the
# actual prompt len and output len
prompt = self.apply_multimodal_chat_transformation(
prompt, mm_content)
prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
)
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@ -695,10 +701,8 @@ class VisionArenaDataset(HuggingFaceDataset):
DEFAULT_OUTPUT_LEN = 128
SUPPORTED_DATASET_PATHS = {
"lmarena-ai/VisionArena-Chat":
lambda x: x["conversation"][0][0]["content"],
"lmarena-ai/vision-arena-bench-v0.1":
lambda x: x["turns"][0][0]["content"]
"lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
"lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
}
IS_MULTIMODAL = True
@ -710,16 +714,14 @@ class VisionArenaDataset(HuggingFaceDataset):
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
if parser_fn is None:
raise ValueError(
f"Unsupported dataset path: {self.dataset_path}")
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
prompt = parser_fn(item)
mm_content = process_image(item["images"][0])
prompt_len = len(tokenizer(prompt).input_ids)
@ -727,15 +729,15 @@ class VisionArenaDataset(HuggingFaceDataset):
# Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the
# actual prompt len
prompt = self.apply_multimodal_chat_transformation(
prompt, mm_content)
prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
)
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@ -760,14 +762,15 @@ class InstructCoderDataset(HuggingFaceDataset):
"likaixin/InstructCoder",
}
def sample(self,
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
@ -779,7 +782,8 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
)
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@ -804,28 +808,28 @@ class MTBenchDataset(HuggingFaceDataset):
"philschmid/mt-bench",
}
def sample(self,
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item['turns'][0]
prompt = item["turns"][0]
# apply template
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False)
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
@ -833,7 +837,8 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
)
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@ -847,23 +852,27 @@ class AIMODataset(HuggingFaceDataset):
"""
Dataset class for processing a AIMO dataset with reasoning questions.
"""
SUPPORTED_DATASET_PATHS = {
"AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
"AI-MO/NuminaMath-CoT"
"AI-MO/aimo-validation-aime",
"AI-MO/NuminaMath-1.5",
"AI-MO/NuminaMath-CoT",
}
def sample(self,
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
**kwargs) -> list:
**kwargs,
) -> list:
sampled_requests = []
dynamic_output = output_len is None
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt, completion = item['problem'], item["solution"]
prompt, completion = item["problem"], item["solution"]
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
@ -871,10 +880,9 @@ class AIMODataset(HuggingFaceDataset):
completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0
if dynamic_output and not is_valid_sequence(prompt_len,
completion_len,
max_prompt_len=2048,
max_total_len=32000):
if dynamic_output and not is_valid_sequence(
prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000
):
continue
sampled_requests.append(
SampleRequest(
@ -882,7 +890,8 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=None,
))
)
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
@ -909,8 +918,8 @@ You are a code completion assistant and your task is to analyze user edits and t
def _format_zeta_prompt(
sample: dict,
original_start_marker: str = "<|editable_region_start|>") -> dict:
sample: dict, original_start_marker: str = "<|editable_region_start|>"
) -> dict:
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
This function formats examples from the NEP dataset
@ -953,10 +962,8 @@ class NextEditPredictionDataset(HuggingFaceDataset):
"zed-industries/zeta": _format_zeta_prompt,
}
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
**kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
self.dataset_path)
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
@ -967,8 +974,10 @@ class NextEditPredictionDataset(HuggingFaceDataset):
prompt=sample["prompt"],
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids),
))
tokenizer(sample["expected_output"]).input_ids
),
)
)
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples, num_requests)
@ -998,17 +1007,21 @@ class ASRDataset(HuggingFaceDataset):
+----------------+----------------------------------------+--------------------------+-----------------------------+
""" # noqa: E501
SUPPORTED_DATASET_PATHS = {
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
"openslr/librispeech_asr",
"facebook/voxpopuli",
"LIUM/tedlium",
"edinburghcstr/ami",
"speechcolab/gigaspeech",
"kensho/spgispeech",
}
DEFAULT_OUTPUT_LEN = 128
IS_MULTIMODAL = True
# TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
"<|notimestamps|>"
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
skip_long_audios: bool = True
def sample(
@ -1019,8 +1032,8 @@ class ASRDataset(HuggingFaceDataset):
**kwargs,
) -> list:
import librosa
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
@ -1043,10 +1056,14 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
)
)
if skipped:
logger.warning("%d samples discarded from dataset due to" \
" their length being greater than" \
" what Whisper supports.", skipped)
logger.warning(
"%d samples discarded from dataset due to"
" their length being greater than"
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests

View File

@ -11,9 +11,9 @@ from typing import Any, Optional
import numpy as np
import torch
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
@ -21,13 +21,14 @@ from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any]) -> None:
def save_to_pytorch_benchmark_format(
args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={"latency": results["latencies"]},
extra_info={k: results[k]
for k in ["avg_latency", "percentiles"]})
extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
)
if pt_records:
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
@ -42,9 +43,11 @@ def main(args: argparse.Namespace):
# the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args))
assert llm.llm_engine.model_config.max_model_len >= (
args.input_len +
args.output_len), ("Please ensure that max_model_len is greater than"
" the sum of input_len and output_len.")
args.input_len + args.output_len
), (
"Please ensure that max_model_len is greater than"
" the sum of input_len and output_len."
)
sampling_params = SamplingParams(
n=args.n,
@ -55,18 +58,16 @@ def main(args: argparse.Namespace):
detokenize=not args.disable_detokenize,
)
print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompts: list[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
dummy_prompt_token_ids = np.random.randint(
10000, size=(args.batch_size, args.input_len)
)
dummy_prompts: list[PromptType] = [
{"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
]
def llm_generate():
if not args.use_beam_search:
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
else:
llm.beam_search(
dummy_prompts,
@ -85,7 +86,8 @@ def main(args: argparse.Namespace):
torch.profiler.ProfilerActivity.CUDA,
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir)),
str(profile_dir)
),
) as p:
llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total"))
@ -103,8 +105,9 @@ def main(args: argparse.Namespace):
if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
profile_dir = (Path(".") / "vllm_benchmark_result" /
f"latency_result_{time.time()}")
profile_dir = (
Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
)
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return
@ -135,7 +138,8 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the latency of processing a single batch of "
"requests till completion.")
"requests till completion."
)
parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8)
@ -152,10 +156,9 @@ if __name__ == "__main__":
default=10,
help="Number of iterations to run for warmup.",
)
parser.add_argument("--num-iters",
type=int,
default=30,
help="Number of iterations to run.")
parser.add_argument(
"--num-iters", type=int, default=30, help="Number of iterations to run."
)
parser.add_argument(
"--profile",
action="store_true",
@ -165,8 +168,10 @@ if __name__ == "__main__":
"--profile-result-dir",
type=str,
default=None,
help=("path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."),
help=(
"path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."
),
)
parser.add_argument(
"--output-json",
@ -177,8 +182,10 @@ if __name__ == "__main__":
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=("Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"),
help=(
"Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
)
parser = EngineArgs.add_cli_args(parser)

View File

@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str):
ValueError: If an invalid mode is provided.
"""
print("Repeat mode: ", mode)
if mode == 'random':
if mode == "random":
repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts)
return repeated_prompts
elif mode == 'tile':
elif mode == "tile":
return prompts * repeat_count
elif mode == 'interleave':
elif mode == "interleave":
repeated_prompts = []
for prompt in prompts:
repeated_prompts.extend([prompt] * repeat_count)
return repeated_prompts
else:
raise ValueError(f"Invalid mode: {mode}, only support "
"'random', 'tile', 'interleave'")
raise ValueError(
f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
)
def main(args):
@ -109,16 +110,16 @@ def main(args):
# we append the document id at the beginning to avoid any of the document
# being the prefix of other documents
prompts = [
str(i) + ' '.join(['hi'] * args.document_length)
str(i) + " ".join(["hi"] * args.document_length)
for i in range(args.num_documents)
]
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
warmup_prompts = [
"This is warm up request " + str(i) + \
' '.join(['hi'] * args.document_length)
for i in range(args.num_documents)]
"This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
for i in range(args.num_documents)
]
# Create the LLM engine
engine_args = EngineArgs.from_cli_args(args)
@ -142,42 +143,52 @@ def main(args):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description=
'Benchmark the performance with or without automatic prefix caching.')
description="Benchmark the performance with or "
"without automatic prefix caching."
)
parser.add_argument(
'--document-length',
"--document-length",
type=int,
# Roughly the number of tokens for a system paper,
# excluding images
default=20000,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')
help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument('--num-documents',
parser.add_argument(
"--num-documents",
type=int,
default=8,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')
help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument("--output-len", type=int, default=10)
parser.add_argument('--repeat-count',
parser.add_argument(
"--repeat-count",
type=int,
default=2,
help='Number of times to repeat each prompt')
help="Number of times to repeat each prompt",
)
parser.add_argument("--repeat-mode",
parser.add_argument(
"--repeat-mode",
type=str,
default='random',
help='The mode to repeat prompts. The supported '
default="random",
help="The mode to repeat prompts. The supported "
'modes are "random", "tile", and "interleave". '
'See repeat_prompts() in the source code for details.')
"See repeat_prompts() in the source code for details.",
)
parser.add_argument("--shuffle-seed",
parser.add_argument(
"--shuffle-seed",
type=int,
default=0,
help='Random seed when the repeat mode is "random"')
help='Random seed when the repeat mode is "random"',
)
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()

View File

@ -63,8 +63,7 @@ class Request:
output_len: int
def sample_tokens(tokenizer: PreTrainedTokenizerBase,
length: int) -> list[int]:
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
vocab = tokenizer.get_vocab()
all_special_ids = set(tokenizer.all_special_ids)
@ -91,8 +90,10 @@ def sample_requests_from_dataset(
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset.
random.shuffle(dataset)
@ -113,8 +114,9 @@ def sample_requests_from_dataset(
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = (len(completion_token_ids)
if fixed_output_len is None else fixed_output_len)
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if min_len <= prompt_len <= max_len:
filtered_requests.append(Request(prompt, prompt_len, output_len))
@ -128,27 +130,27 @@ def sample_requests_from_random(
fixed_output_len: Optional[int],
prefix_len: int,
) -> list[Request]:
requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
min_len, max_len = input_length_range
for i in range(num_requests):
unique_part_token_ids = sample_tokens(
tokenizer,
random.randint(min_len - prefix_len, max_len - prefix_len))
tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
)
prompt_token_ids = prefix_token_ids + unique_part_token_ids
prompt = tokenizer.decode(prompt_token_ids)
prompt_len = len(prompt_token_ids)
assert (min_len <= prompt_len <= max_len
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
assert min_len <= prompt_len <= max_len, (
f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
)
requests.append(Request(prompt, prompt_len, fixed_output_len))
return requests
def repeat_and_sort_requests(requests: list[Request],
repeat_count: int,
sort: bool = False) -> list[str]:
def repeat_and_sort_requests(
requests: list[Request], repeat_count: int, sort: bool = False
) -> list[str]:
repeated_requests = requests * repeat_count
if sort:
repeated_requests.sort(key=lambda x: x[1])
@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request],
def main(args):
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
input_length_range = tuple(map(int, args.input_length_range.split(':')))
input_length_range = tuple(map(int, args.input_length_range.split(":")))
random.seed(args.seed)
if args.dataset_path is not None:
if args.prefix_len > 0:
raise ValueError("prefix-len is not supported when "
"dataset-path is provided.")
print(f"Start to sample {args.num_prompts} prompts "
f"from {args.dataset_path}")
raise ValueError(
"prefix-len is not supported when dataset-path is provided."
)
print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
@ -196,14 +198,16 @@ def main(args):
llm = LLM(**dataclasses.asdict(engine_args))
sampling_params = SamplingParams(temperature=0,
sampling_params = SamplingParams(
temperature=0,
max_tokens=args.output_len,
detokenize=not args.disable_detokenize)
detokenize=not args.disable_detokenize,
)
print("Testing filtered requests")
prompts = repeat_and_sort_requests(filtered_requests,
repeat_count=args.repeat_count,
sort=args.sort)
prompts = repeat_and_sort_requests(
filtered_requests, repeat_count=args.repeat_count, sort=args.sort
)
print("------start generating------")
test_prefix(
@ -215,29 +219,35 @@ def main(args):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description=
'Benchmark the performance with or without automatic prefix caching.')
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--num-prompts',
description="Benchmark the performance with or without "
"automatic prefix caching."
)
parser.add_argument(
"--dataset-path", type=str, default=None, help="Path to the dataset."
)
parser.add_argument("--output-len", type=int, default=10)
parser.add_argument(
"--num-prompts",
type=int,
required=True,
help="Number of the prompts sampled from dataset")
parser.add_argument('--repeat-count',
help="Number of the prompts sampled from dataset",
)
parser.add_argument(
"--repeat-count",
type=int,
default=1,
help='Number of times to repeat each prompt')
parser.add_argument('--sort',
action='store_true',
help='Sort prompts by input length')
parser.add_argument('--input-length-range',
help="Number of times to repeat each prompt",
)
parser.add_argument(
"--sort", action="store_true", help="Sort prompts by input length"
)
parser.add_argument(
"--input-length-range",
type=str,
required=True,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')
help="Range of input lengths for sampling prompts,"
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument(
"--prefix-len",
type=int,
@ -248,10 +258,12 @@ if __name__ == "__main__":
"when dataset-path is not provided.",
)
parser.add_argument(
'--disable-detokenize',
action='store_true',
help=("Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"),
"--disable-detokenize",
action="store_true",
help=(
"Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
)
parser = EngineArgs.add_cli_args(parser)

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline prioritization."""
import argparse
import dataclasses
import json
@ -13,7 +14,7 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
#Select a equi-probable random priority
# Select a equi-probable random priority
def get_random_flag():
return 0 if random.random() < 0.5 else 1
@ -33,8 +34,10 @@ def sample_requests(
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset.
random.shuffle(dataset)
@ -51,8 +54,9 @@ def sample_requests(
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
@ -74,13 +78,16 @@ def run_vllm(
disable_detokenize: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
for request in requests), (
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of"
" input_len and output_len for all requests.")
" input_len and output_len for all requests."
)
# Add the requests to the engine.
prompts = []
@ -97,7 +104,8 @@ def run_vllm(
ignore_eos=True,
max_tokens=output_len,
detokenize=not disable_detokenize,
))
)
)
start = time.perf_counter()
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
@ -111,26 +119,33 @@ def main(args: argparse.Namespace):
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
args.tokenizer, trust_remote_code=args.trust_remote_code
)
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len,
get_random_flag()) for _ in range(args.num_prompts)]
requests = [
(prompt, args.input_len, args.output_len, get_random_flag())
for _ in range(args.num_prompts)
]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
requests = sample_requests(
args.dataset, args.num_prompts, tokenizer, args.output_len
)
if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.n,
EngineArgs.from_cli_args(args),
args.disable_detokenize)
elapsed_time = run_vllm(
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len, priority in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
total_num_tokens = sum(
prompt_len + output_len for _, prompt_len, output_len, priority in requests
)
print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s"
)
# Output JSON results if specified
if args.output_json:
@ -147,41 +162,44 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
default="vllm")
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument("--input-len",
parser.add_argument(
"--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
)
parser.add_argument(
"--dataset", type=str, default=None, help="Path to the dataset."
)
parser.add_argument(
"--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
help="Input prompt length for each request",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=200,
help="Number of prompts to process.")
"output length from the dataset.",
)
parser.add_argument(
'--output-json',
"--n", type=int, default=1, help="Number of generated sequences per prompt."
)
parser.add_argument(
"--num-prompts", type=int, default=200, help="Number of prompts to process."
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
help="Path to save the throughput results in JSON format.",
)
parser.add_argument(
'--disable-detokenize',
action='store_true',
help=("Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"),
"--disable-detokenize",
action="store_true",
help=(
"Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
)
parser = EngineArgs.add_cli_args(parser)

View File

@ -20,6 +20,7 @@ On the client side, run:
--endpoint /generate_stream
to the end of the command above.
"""
import argparse
import asyncio
import gc
@ -34,12 +35,16 @@ from datetime import datetime
from typing import Any, Optional
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
RequestFuncOutput)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
)
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
@ -50,12 +55,21 @@ try:
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
ConversationDataset, HuggingFaceDataset,
InstructCoderDataset, MTBenchDataset,
NextEditPredictionDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
from benchmark_dataset import (
AIMODataset,
ASRDataset,
BurstGPTDataset,
ConversationDataset,
HuggingFaceDataset,
InstructCoderDataset,
MTBenchDataset,
NextEditPredictionDataset,
RandomDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@ -118,7 +132,8 @@ async def get_request(
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.")
f"A positive burstiness factor is expected, but given {burstiness}."
)
theta = 1.0 / (request_rate * burstiness)
for request in input_requests:
@ -164,8 +179,10 @@ def calculate_metrics(
# bundled together
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
tokenizer(
outputs[i].generated_text, add_special_tokens=False
).input_ids
)
actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len
tpot = 0
@ -188,16 +205,19 @@ def calculate_metrics(
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(
goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(
goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(
goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
@ -208,7 +228,8 @@ def calculate_metrics(
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
stacklevel=2,
)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@ -217,27 +238,31 @@ def calculate_metrics(
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
mean_ttft_ms=np.mean(ttfts or 0)
* 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
for p in selected_percentiles],
percentiles_ttft_ms=[
(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
for p in selected_percentiles],
percentiles_tpot_ms=[
(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
for p in selected_percentiles],
percentiles_itl_ms=[
(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
percentiles_e2el_ms=[
(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
],
)
return metrics, actual_output_lens
@ -270,10 +295,12 @@ async def benchmark(
raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len, test_mm_content = \
input_requests[0].prompt, input_requests[0].prompt_len, \
input_requests[0].expected_output_len, \
input_requests[0].multi_modal_data
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0].prompt,
input_requests[0].prompt_len,
input_requests[0].expected_output_len,
input_requests[0].multi_modal_data,
)
assert test_mm_content is None or isinstance(test_mm_content, dict)
test_input = RequestFuncInput(
@ -293,19 +320,21 @@ async def benchmark(
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}")
f"are correctly specified. Error: {test_output.error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")
if lora_modules:
# For each input request, choose a LoRA module at random.
lora_modules = iter(
[random.choice(lora_modules) \
for _ in range(len(input_requests))])
[random.choice(lora_modules) for _ in range(len(input_requests))]
)
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id,
profile_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
api_url=base_url + "/start_profile",
@ -314,15 +343,13 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_body=extra_body)
extra_body=extra_body,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")
if burstiness == 1.0:
distribution = "Poisson process"
else:
distribution = "Gamma distribution"
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})")
@ -334,29 +361,30 @@ async def benchmark(
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None)
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request.prompt, \
request.prompt_len, request.expected_output_len, \
request.multi_modal_data
prompt, prompt_len, output_len, mm_content = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id,
request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
api_url=api_url,
@ -365,11 +393,13 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_body=extra_body)
extra_body=extra_body,
)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
pbar=pbar)))
limited_request_func(request_func_input=request_func_input, pbar=pbar)
)
)
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
@ -401,22 +431,32 @@ async def benchmark(
goodput_config_dict=goodput_config_dict,
)
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", metrics.request_throughput
)
)
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
print(
"{:<40} {:<10.2f}".format(
"Request goodput (req/s):", metrics.request_goodput
)
)
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput
)
)
result = {
"duration": benchmark_duration,
@ -424,8 +464,7 @@ async def benchmark(
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput:":
metrics.request_goodput if goodput_config_dict else None,
"request_goodput:": metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
@ -448,29 +487,35 @@ async def benchmark(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
metrics, f"mean_{metric_attribute_name}_ms"
)
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
metrics, f"median_{metric_attribute_name}_ms"
)
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
metrics, f"std_{metric_attribute_name}_ms"
)
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@ -490,12 +535,14 @@ def check_goodput_args(args):
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of "
f"{str(VALID_NAMES)}. ")
f"{str(VALID_NAMES)}. "
)
if slo_val < 0:
raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
"non-negative.")
"non-negative."
)
return goodput_config_dict
@ -508,31 +555,42 @@ def parse_goodput(slo_pairs):
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" "
'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err
"number in milliseconds."
) from err
return goodput_config_dict
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any],
file_name: str) -> None:
def save_to_pytorch_benchmark_format(
args: argparse.Namespace, results: dict[str, Any], file_name: str
) -> None:
metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
"median_ttft_ms",
"mean_ttft_ms",
"std_ttft_ms",
"p99_ttft_ms",
"mean_tpot_ms",
"median_tpot_ms",
"std_tpot_ms",
"p99_tpot_ms",
"median_itl_ms",
"mean_itl_ms",
"std_itl_ms",
"p99_itl_ms",
]
# These raw data might be useful, but they are rather big. They can be added
# later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={k: [results[k]]
for k in metrics},
metrics={k: [results[k]] for k in metrics},
extra_info={
k: results[k]
for k in results if k not in metrics and k not in ignored_metrics
})
for k in results
if k not in metrics and k not in ignored_metrics
},
)
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
@ -557,34 +615,42 @@ def main(args: argparse.Namespace):
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"
tokenizer = get_tokenizer(tokenizer_id,
tokenizer = get_tokenizer(
tokenizer_id,
tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code)
trust_remote_code=args.trust_remote_code,
)
if args.dataset_name is None:
raise ValueError(
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.")
"'--dataset-path' if required."
)
if args.dataset_name == "sonnet":
dataset = SonnetDataset(dataset_path=args.dataset_path)
# For the "sonnet" dataset, formatting depends on the backend.
if args.backend == "openai-chat":
input_requests = dataset.sample(num_requests=args.num_prompts,
input_requests = dataset.sample(
num_requests=args.num_prompts,
input_len=args.sonnet_input_len,
output_len=args.sonnet_output_len,
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=False)
return_prompt_formatted=False,
)
else:
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.")
input_requests = dataset.sample(num_requests=args.num_prompts,
"Tokenizer/model must have chat template for sonnet dataset."
)
input_requests = dataset.sample(
num_requests=args.num_prompts,
input_len=args.sonnet_input_len,
output_len=args.sonnet_output_len,
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=True)
return_prompt_formatted=True,
)
elif args.dataset_name == "hf":
# all following datasets are implemented from the
@ -611,23 +677,30 @@ def main(args: argparse.Namespace):
dataset_class = ASRDataset
args.hf_split = "train"
else:
supported_datasets = set([
dataset_name for cls in HuggingFaceDataset.__subclasses__()
supported_datasets = set(
[
dataset_name
for cls in HuggingFaceDataset.__subclasses__()
for dataset_name in cls.SUPPORTED_DATASET_PATHS
])
]
)
raise ValueError(
f"Unsupported dataset path: {args.dataset_path}. "
"Huggingface dataset only supports dataset_path"
f" from one of following: {supported_datasets}. "
"Please consider contributing if you would "
"like to add support for additional dataset formats.")
"like to add support for additional dataset formats."
)
if (dataset_class.IS_MULTIMODAL and backend not in \
["openai-chat", "openai-audio"]):
if dataset_class.IS_MULTIMODAL and backend not in [
"openai-chat",
"openai-audio",
]:
# multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' and " \
"'openai-audio' backend.")
"Multi-modal content is only supported on 'openai-chat' and "
"'openai-audio' backend."
)
input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
@ -642,26 +715,24 @@ def main(args: argparse.Namespace):
else:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"sharegpt":
lambda: ShareGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
"burstgpt":
lambda: BurstGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
"random":
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
)
),
}
try:
@ -677,15 +748,16 @@ def main(args: argparse.Namespace):
"top_p": args.top_p,
"top_k": args.top_k,
"min_p": args.min_p,
"temperature": args.temperature
}.items() if v is not None
"temperature": args.temperature,
}.items()
if v is not None
}
# Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError(
"Sampling parameters are only supported by openai-compatible "
"backends.")
"Sampling parameters are only supported by openai-compatible backends."
)
if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
@ -709,15 +781,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
))
)
)
# Save config and results to json
if args.save_result or args.append_result:
@ -742,8 +813,9 @@ def main(args: argparse.Namespace):
"Invalid metadata format. Please use KEY=VALUE format."
)
# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate
< float("inf") else "inf")
result_json["request_rate"] = (
args.request_rate if args.request_rate < float("inf") else "inf"
)
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency
@ -753,24 +825,31 @@ def main(args: argparse.Namespace):
if not args.save_detailed:
# Remove fields with too many data points
for field in [
"input_lens", "output_lens", "ttfts", "itls",
"generated_texts", "errors"
"input_lens",
"output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]:
if field in result_json:
del result_json[field]
# Save to file
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None else "")
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
file_name = os.path.join(args.result_dir, file_name)
with open(file_name,
mode="a+" if args.append_result else "w",
encoding='utf-8') as outfile:
with open(
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
) as outfile:
# Append a newline.
if args.append_result and outfile.tell() != 0:
outfile.write("\n")
@ -780,7 +859,8 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
description="Benchmark the online serving throughput."
)
parser.add_argument(
"--backend",
type=str,
@ -809,11 +889,13 @@ if __name__ == "__main__":
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
parser.add_argument(
"--dataset-path",
type=str,
default=None,
help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.")
"Or the huggingface dataset ID if using HF dataset.",
)
parser.add_argument(
"--max-concurrency",
type=int,
@ -825,7 +907,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
@ -836,8 +919,7 @@ if __name__ == "__main__":
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@ -850,11 +932,13 @@ if __name__ == "__main__":
"--logprobs",
type=int,
default=None,
help=("Number of logprobs-per-token to compute & return as part of "
help=(
"Number of logprobs-per-token to compute & return as part of "
"the request. If unspecified, then either (1) if beam search "
"is disabled, no logprobs are computed & a single dummy "
"logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"),
"is enabled 1 logprob per token is computed"
),
)
parser.add_argument(
"--request-rate",
@ -938,35 +1022,38 @@ if __name__ == "__main__":
"--ignore-eos",
action="store_true",
help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
"Default value is \"ttft,tpot,itl\".")
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
'Default value is "ttft,tpot,itl".',
)
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
"Default value is \"99\". "
"Use \"--percentile-metrics\" to select metrics.",
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
'Default value is "99". '
'Use "--percentile-metrics" to select metrics.',
)
parser.add_argument(
"--goodput",
nargs="+",
required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
'"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
# group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options")
@ -974,22 +1061,19 @@ if __name__ == "__main__":
"--sonnet-input-len",
type=int,
default=550,
help=
"Number of input tokens per request, used only for sonnet dataset.",
help="Number of input tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-output-len",
type=int,
default=150,
help=
"Number of output tokens per request, used only for sonnet dataset.",
help="Number of output tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
help="Number of prefix tokens per request, used only for sonnet dataset.",
)
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
@ -998,22 +1082,21 @@ if __name__ == "__main__":
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
"from the ShareGPT dataset.",
)
random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
"--random-input-len",
type=int,
default=1024,
help=
"Number of input tokens per request, used only for random sampling.",
help="Number of input tokens per request, used only for random sampling.",
)
random_group.add_argument(
"--random-output-len",
type=int,
default=128,
help=
"Number of output tokens per request, used only for random sampling.",
help="Number of output tokens per request, used only for random sampling.",
)
random_group.add_argument(
"--random-range-ratio",
@ -1028,23 +1111,23 @@ if __name__ == "__main__":
"--random-prefix-len",
type=int,
default=0,
help=("Number of fixed prefix tokens before the random context "
help=(
"Number of fixed prefix tokens before the random context "
"in a request. "
"The total input length is the sum of `random-prefix-len` and "
"a random "
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."),
"input_len * (1 + range_ratio)]."
),
)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument(
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
)
hf_group.add_argument(
"--hf-split", type=str, default=None, help="Split of the HF dataset."
)
hf_group.add_argument(
"--hf-output-len",
type=int,
@ -1058,52 +1141,58 @@ if __name__ == "__main__":
"--top-p",
type=float,
default=None,
help="Top-p sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Top-p sampling parameter. Only has effect on openai-compatible backends.",
)
sampling_group.add_argument(
"--top-k",
type=int,
default=None,
help="Top-k sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Top-k sampling parameter. Only has effect on openai-compatible backends.",
)
sampling_group.add_argument(
"--min-p",
type=float,
default=None,
help="Min-p sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Min-p sampling parameter. Only has effect on openai-compatible backends.",
)
sampling_group.add_argument(
"--temperature",
type=float,
default=None,
help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).")
"decoding (i.e. temperature==0.0).",
)
parser.add_argument(
'--tokenizer-mode',
"--tokenizer-mode",
type=str,
default="auto",
choices=['auto', 'slow', 'mistral', 'custom'],
choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
"always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.')
'"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument("--served-model-name",
parser.add_argument(
"--served-model-name",
type=str,
default=None,
help="The model name used in the API. "
"If not specified, the model name will be the "
"same as the ``--model`` argument. ")
"same as the ``--model`` argument. ",
)
parser.add_argument("--lora-modules",
nargs='+',
parser.add_argument(
"--lora-modules",
nargs="+",
default=None,
help="A subset of LoRA module names passed in when "
"launching the server. For each request, the "
"script chooses a LoRA module at random.")
"script chooses a LoRA module at random.",
)
args = parser.parse_args()

View File

@ -19,6 +19,7 @@ On the client side, run:
--endpoint /generate_stream
to the end of the command above.
"""
import argparse
import asyncio
import copy
@ -36,11 +37,15 @@ from typing import Optional
import datasets
import numpy as np
import pandas as pd
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
RequestFuncOutput)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
RequestFuncInput,
RequestFuncOutput,
)
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
@ -52,7 +57,8 @@ except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.v1.structured_output.backend_xgrammar import (
has_xgrammar_unsupported_json_features)
has_xgrammar_unsupported_json_features,
)
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@ -98,6 +104,7 @@ class SampleRequest:
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
prompt: str
prompt_len: int
expected_output_len: int
@ -106,31 +113,27 @@ class SampleRequest:
completion: str = None
def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> list[SampleRequest]:
if args.dataset == 'json' or args.dataset == 'json-unique':
def sample_requests(
tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
) -> list[SampleRequest]:
if args.dataset == "json" or args.dataset == "json-unique":
if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__))
args.json_schema_path = os.path.join(dir_path,
"structured_schemas",
"structured_schema_1.json")
args.json_schema_path = os.path.join(
dir_path, "structured_schemas", "structured_schema_1.json"
)
json_schemas = []
with open(args.json_schema_path) as f:
schema = json.load(f)
if args.dataset == 'json-unique':
json_schemas = [
copy.deepcopy(schema) for _ in range(args.num_prompts)
]
if args.dataset == "json-unique":
json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)]
for i in range(len(json_schemas)):
if "properties" not in json_schemas[i]:
json_schemas[i]["properties"] = {}
json_schemas[i]["properties"][
f"__optional_field_{uuid.uuid4()}"] = {
"type":
"string",
"description":
"An unique optional field to avoid cached schemas"
json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = {
"type": "string",
"description": "An unique optional field to avoid cached schemas",
}
else:
json_schemas = [schema] * args.num_prompts
@ -142,11 +145,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
return json_schemas[index % len(json_schemas)]
requests = [
SampleRequest(prompt=gen_prompt(i),
SampleRequest(
prompt=gen_prompt(i),
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
expected_output_len=args.output_len,
schema=get_schema(i),
structure_type=args.structure_type)
structure_type=args.structure_type,
)
for i in range(args.num_prompts)
]
@ -170,11 +175,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
structure_type=args.structure_type)
structure_type=args.structure_type,
)
for _ in range(args.num_prompts)
]
@ -188,11 +195,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=regex,
structure_type=args.structure_type)
structure_type=args.structure_type,
)
for _ in range(args.num_prompts)
]
@ -203,48 +212,55 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=choice,
structure_type=args.structure_type)
structure_type=args.structure_type,
)
for _ in range(args.num_prompts)
]
elif args.dataset == "xgrammar_bench":
requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
split="train")
dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train")
full_dataset_len = len(dataset)
def _filter_func(item):
import json
schema = json.loads(item["schema"])
return not has_xgrammar_unsupported_json_features(schema)
dataset = dataset.filter(_filter_func)
num_filtered_out = full_dataset_len - len(dataset)
print(f"dataset has {len(dataset)} entries after filtering "
f"out {num_filtered_out} entries with unsupported features")
print(
f"dataset has {len(dataset)} entries after filtering "
f"out {num_filtered_out} entries with unsupported features"
)
len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts):
idx = data_point_idx
while idx >= len_dataset:
idx -= len_dataset
schema = dataset["schema"][idx]
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
tokenize=False,
add_generation_prompt=True)
prompt = tokenizer.apply_chat_template(
dataset["prompt"][idx], tokenize=False, add_generation_prompt=True
)
input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx]
requests.append(
SampleRequest(prompt=prompt,
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
structure_type=args.structure_type,
completion=completion))
completion=completion,
)
)
return requests
@ -276,7 +292,8 @@ async def get_request(
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.")
f"A positive burstiness factor is expected, but given {burstiness}."
)
theta = 1.0 / (request_rate * burstiness)
for i, request in enumerate(input_requests):
@ -318,8 +335,8 @@ def calculate_metrics(
# multiple output tokens may be bundled together
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
)
actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len
tpot = 0
@ -343,16 +360,19 @@ def calculate_metrics(
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(
goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(
goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(
goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
@ -363,7 +383,8 @@ def calculate_metrics(
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
stacklevel=2,
)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@ -372,27 +393,31 @@ def calculate_metrics(
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
mean_ttft_ms=np.mean(ttfts or 0)
* 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
for p in selected_percentiles],
percentiles_ttft_ms=[
(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
for p in selected_percentiles],
percentiles_tpot_ms=[
(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
for p in selected_percentiles],
percentiles_itl_ms=[
(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
percentiles_e2el_ms=[
(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
],
)
return metrics, actual_output_lens
@ -429,12 +454,13 @@ async def benchmark(
print("Starting initial single prompt test run...")
structured_output_req_idx = random.sample(
range(len(input_requests)),
int(len(input_requests) * structured_output_ratio))
range(len(input_requests)), int(len(input_requests) * structured_output_ratio)
)
test_request = input_requests[0]
test_req_extra_body = (prepare_extra_body(test_request)
if 0 in structured_output_req_idx else None)
test_req_extra_body = (
prepare_extra_body(test_request) if 0 in structured_output_req_idx else None
)
test_input = RequestFuncInput(
model=model_id,
prompt=test_request.prompt,
@ -448,7 +474,8 @@ async def benchmark(
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}")
f"are correctly specified. Error: {test_output.error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")
@ -467,10 +494,7 @@ async def benchmark(
if profile_output.success:
print("Profiler started")
if burstiness == 1.0:
distribution = "Poisson process"
else:
distribution = "Gamma distribution"
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})")
@ -482,24 +506,21 @@ async def benchmark(
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None)
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
expected: list[str] = []
async for i, request in get_request(input_requests, request_rate,
burstiness):
extra_body = prepare_extra_body(
request) if i in structured_output_req_idx else None
async for i, request in get_request(input_requests, request_rate, burstiness):
extra_body = (
prepare_extra_body(request) if i in structured_output_req_idx else None
)
request_func_input = RequestFuncInput(
model=model_id,
prompt=request.prompt,
@ -512,8 +533,9 @@ async def benchmark(
expected.append(request.completion)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
pbar=pbar)))
limited_request_func(request_func_input=request_func_input, pbar=pbar)
)
)
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
@ -545,54 +567,58 @@ async def benchmark(
goodput_config_dict=goodput_config_dict,
)
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", metrics.request_throughput
)
)
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
print(
"{:<40} {:<10.2f}".format(
"Request goodput (req/s):", metrics.request_goodput
)
)
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput
)
)
result = {
"duration":
benchmark_duration,
"completed":
metrics.completed,
"total_input_tokens":
metrics.total_input,
"total_output_tokens":
metrics.total_output,
"request_throughput":
metrics.request_throughput,
"output_throughput":
metrics.output_throughput,
"total_token_throughput":
metrics.total_token_throughput,
"ttft_description":
pd.Series([output.ttft for output in outputs]).describe().to_dict(),
"tpot_description":
pd.Series([output.tpot for output in outputs]).describe().to_dict(),
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"ttft_description": pd.Series([output.ttft for output in outputs])
.describe()
.to_dict(),
"tpot_description": pd.Series([output.tpot for output in outputs])
.describe()
.to_dict(),
"input_lens": [output.prompt_len for output in outputs],
"output_lens":
actual_output_lens,
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"errors": [output.error for output in outputs],
}
ret = [{
'generated': output.generated_text,
'expected': gt
} for output, gt in zip(outputs, expected)]
ret = [
{"generated": output.generated_text, "expected": gt}
for output, gt in zip(outputs, expected)
]
def process_one_metric(
# E.g., "ttft"
@ -606,29 +632,35 @@ async def benchmark(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
metrics, f"mean_{metric_attribute_name}_ms"
)
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
metrics, f"median_{metric_attribute_name}_ms"
)
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
metrics, f"std_{metric_attribute_name}_ms"
)
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@ -638,13 +670,13 @@ async def benchmark(
def evaluate(ret, args):
def _eval_correctness_json(expected, actual):
# extract json string from string using regex
import re
actual = actual.replace('\n', '').replace(' ', '').strip()
actual = actual.replace("\n", "").replace(" ", "").strip()
try:
actual = re.search(r'\{.*\}', actual).group()
actual = re.search(r"\{.*\}", actual).group()
actual = json.loads(actual)
except Exception:
return False
@ -656,28 +688,32 @@ def evaluate(ret, args):
def _eval_correctness_regex(expected, actual):
import re
return re.match(args.regex, actual) is not None
def _eval_correctness(expected, actual):
if args.structure_type == 'guided_json':
if args.structure_type == "guided_json":
return _eval_correctness_json(expected, actual)
elif args.structure_type == 'guided_regex':
elif args.structure_type == "guided_regex":
return _eval_correctness_regex(expected, actual)
elif args.structure_type == 'guided_choice':
elif args.structure_type == "guided_choice":
return _eval_correctness_choice(expected, actual)
else:
return None
scores = []
for res in ret:
score = _eval_correctness(res['expected'], res['generated'])
res['correctness'] = score
score = _eval_correctness(res["expected"], res["generated"])
res["correctness"] = score
scores.append(score)
not_none_scores = [score for score in scores if score is not None]
return (sum(not_none_scores) / len(not_none_scores) *
100) if len(not_none_scores) > 0 else None
return (
(sum(not_none_scores) / len(not_none_scores) * 100)
if len(not_none_scores) > 0
else None
)
def parse_goodput(slo_pairs):
@ -689,9 +725,10 @@ def parse_goodput(slo_pairs):
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" "
'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err
"number in milliseconds."
) from err
return goodput_config_dict
@ -705,12 +742,14 @@ def check_goodput_args(args):
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of "
f"{str(VALID_NAMES)}. ")
f"{str(VALID_NAMES)}. "
)
if slo_val < 0:
raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
"non-negative.")
"non-negative."
)
return goodput_config_dict
@ -736,19 +775,19 @@ def main(args: argparse.Namespace):
tokenizer_mode=args.tokenizer_mode,
)
if args.dataset == 'grammar':
args.structure_type = 'guided_grammar'
elif args.dataset == 'regex':
args.structure_type = 'guided_regex'
elif args.dataset == 'choice':
args.structure_type = 'guided_choice'
if args.dataset == "grammar":
args.structure_type = "guided_grammar"
elif args.dataset == "regex":
args.structure_type = "guided_regex"
elif args.dataset == "choice":
args.structure_type = "guided_choice"
else:
args.structure_type = 'guided_json'
args.structure_type = "guided_json"
if args.no_structured_output:
args.structured_output_ratio = 0
if args.save_results:
result_file_name = f'{args.structured_output_ratio}guided'
result_file_name = f"{args.structured_output_ratio}guided"
result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}"
@ -776,36 +815,29 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency,
structured_output_ratio=args.structured_output_ratio,
goodput_config_dict=goodput_config_dict,
))
)
)
# Save config and results to json
score = evaluate(ret, args)
print("correct_rate(%)", score, '\n')
print("correct_rate(%)", score, "\n")
if args.save_results:
results = {
"backend":
backend,
"model_id":
model_id,
"tokenizer_id":
tokenizer_id,
"num_prompts":
args.num_prompts,
"request_rate":
args.request_rate if args.request_rate < float("inf") else "inf",
"burstiness":
args.burstiness,
"max_concurrency":
args.max_concurrency,
"correct_rate(%)":
score
"backend": backend,
"model_id": model_id,
"tokenizer_id": tokenizer_id,
"num_prompts": args.num_prompts,
"request_rate": args.request_rate
if args.request_rate < float("inf")
else "inf",
"burstiness": args.burstiness,
"max_concurrency": args.max_concurrency,
"correct_rate(%)": score,
}
results = {"outputs": ret, **results, **benchmark_result}
@ -814,13 +846,14 @@ def main(args: argparse.Namespace):
result_file_name = args.result_filename
if args.result_dir:
result_file_name = os.path.join(args.result_dir, result_file_name)
with open(result_file_name, "w", encoding='utf-8') as outfile:
with open(result_file_name, "w", encoding="utf-8") as outfile:
json.dump(results, outfile, indent=4)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
description="Benchmark the online serving throughput."
)
parser.add_argument(
"--backend",
type=str,
@ -842,16 +875,14 @@ if __name__ == "__main__":
default="/v1/completions",
help="API endpoint.",
)
parser.add_argument("--dataset",
default='json',
choices=[
'json', 'json-unique', 'grammar', 'regex',
'choice', 'xgrammar_bench'
])
parser.add_argument("--json-schema-path",
type=str,
default=None,
help="Path to json schema.")
parser.add_argument(
"--dataset",
default="json",
choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"],
)
parser.add_argument(
"--json-schema-path", type=str, default=None, help="Path to json schema."
)
parser.add_argument(
"--max-concurrency",
type=int,
@ -863,7 +894,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
type=str,
@ -873,15 +905,13 @@ if __name__ == "__main__":
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default="auto",
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--num-prompts",
@ -958,44 +988,51 @@ if __name__ == "__main__":
"--ignore-eos",
action="store_true",
help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
"Default value is \"ttft,tpot,itl\".")
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
'Default value is "ttft,tpot,itl".',
)
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
"Default value is \"99\". "
"Use \"--percentile-metrics\" to select metrics.",
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
'Default value is "99". '
'Use "--percentile-metrics" to select metrics.',
)
parser.add_argument(
"--goodput",
nargs="+",
required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
'"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
parser.add_argument("--no-structured-output",
action='store_true',
parser.add_argument(
"--no-structured-output",
action="store_true",
default=False,
help="Whether to disable JSON decoding or not.")
parser.add_argument("--structured-output-ratio",
help="Whether to disable JSON decoding or not.",
)
parser.add_argument(
"--structured-output-ratio",
type=float,
default=1.0,
help="Ratio of Structured Outputs requests")
help="Ratio of Structured Outputs requests",
)
args = parser.parse_args()
main(args)

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
@ -11,18 +12,25 @@ from typing import Any, Optional, Union
import torch
import uvloop
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
ConversationDataset, InstructCoderDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from benchmark_dataset import (
AIMODataset,
BurstGPTDataset,
ConversationDataset,
InstructCoderDataset,
RandomDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
build_async_engine_client_from_engine_args,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
@ -37,23 +45,30 @@ def run_vllm(
disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (
request.prompt_len + request.expected_output_len)
for request in requests), (
llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.")
" prompt_len and expected_output_len for all requests."
)
# Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = []
sampling_params: list[SamplingParams] = []
for request in requests:
prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data)
if "prompt_token_ids" in request.prompt else \
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
TokensPrompt(
prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data,
)
if "prompt_token_ids" in request.prompt
else TextPrompt(
prompt=request.prompt, multi_modal_data=request.multi_modal_data
)
)
sampling_params.append(
SamplingParams(
n=n,
@ -62,7 +77,8 @@ def run_vllm(
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
)
)
lora_requests: Optional[list[LoRARequest]] = None
if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests]
@ -72,10 +88,9 @@ def run_vllm(
outputs = None
if not use_beam_search:
start = time.perf_counter()
outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
outputs = llm.generate(
prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
)
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
@ -91,7 +106,8 @@ def run_vllm(
beam_width=n,
max_tokens=output_len,
ignore_eos=True,
))
),
)
end = time.perf_counter()
return end - start, outputs
@ -100,21 +116,25 @@ def run_vllm_chat(
requests: list[SampleRequest],
n: int,
engine_args: EngineArgs,
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
"""
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
multimodal models as it properly handles multimodal inputs and chat
formatting. For non-multimodal models, use run_vllm() instead.
"""
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (
request.prompt_len + request.expected_output_len)
for request in requests), (
llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of "
"prompt_len and expected_output_len for all requests.")
"prompt_len and expected_output_len for all requests."
)
prompts = []
sampling_params: list[SamplingParams] = []
@ -128,7 +148,8 @@ def run_vllm_chat(
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
)
)
start = time.perf_counter()
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
@ -145,14 +166,17 @@ async def run_vllm_async(
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
engine_args, disable_frontend_multiprocessing
) as llm:
model_config = await llm.get_model_config()
assert all(
model_config.max_model_len >= (request.prompt_len +
request.expected_output_len)
for request in requests), (
model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.")
" prompt_len and expected_output_len for all requests."
)
# Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = []
@ -160,11 +184,15 @@ async def run_vllm_async(
lora_requests: list[Optional[LoRARequest]] = []
for request in requests:
prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data)
if "prompt_token_ids" in request.prompt else \
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
TokensPrompt(
prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data,
)
if "prompt_token_ids" in request.prompt
else TextPrompt(
prompt=request.prompt, multi_modal_data=request.multi_modal_data
)
)
sampling_params.append(
SamplingParams(
n=n,
@ -173,17 +201,16 @@ async def run_vllm_async(
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
)
)
lora_requests.append(request.lora_request)
generators = []
start = time.perf_counter()
for i, (prompt, sp,
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
generator = llm.generate(prompt,
sp,
lora_request=lr,
request_id=f"test{i}")
for i, (prompt, sp, lr) in enumerate(
zip(prompts, sampling_params, lora_requests)
):
generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
@ -202,7 +229,8 @@ def run_hf(
disable_detokenize: bool = False,
) -> float:
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
@ -225,14 +253,15 @@ def run_hf(
# Check if we can add more requests to the batch.
next_prompt_len = requests[i + 1].prompt_len
next_output_len = requests[i + 1].expected_output_len
if (max(max_prompt_len, next_prompt_len) +
max(max_output_len, next_output_len)) <= 2048:
if (
max(max_prompt_len, next_prompt_len)
+ max(max_output_len, next_output_len)
) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt",
padding=True).input_ids
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=True,
@ -262,6 +291,7 @@ def run_mii(
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [request.prompt for request in requests]
@ -273,8 +303,9 @@ def run_mii(
return end - start
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any]) -> None:
def save_to_pytorch_benchmark_format(
args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
@ -282,9 +313,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
"tokens_per_second": [results["tokens_per_second"]],
},
extra_info={
k: results[k]
for k in ["elapsed_time", "num_requests", "total_num_tokens"]
})
k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
},
)
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
@ -316,7 +347,8 @@ def get_requests(args, tokenizer):
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.")
"Tokenizer/model must have chat template for sonnet dataset."
)
dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True
@ -325,21 +357,21 @@ def get_requests(args, tokenizer):
elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
common_kwargs["dataset_subset"] = None
common_kwargs["dataset_split"] = "train"
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset
common_kwargs['dataset_split'] = "train"
common_kwargs["dataset_split"] = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
common_kwargs["dataset_subset"] = args.hf_subset
common_kwargs["dataset_split"] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_cls = AIMODataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
common_kwargs["dataset_subset"] = None
common_kwargs["dataset_split"] = "train"
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
@ -354,10 +386,10 @@ def main(args: argparse.Namespace):
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
args.tokenizer, trust_remote_code=args.trust_remote_code
)
requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None
for request in requests)
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
request_outputs: Optional[list[RequestOutput]] = None
if args.backend == "vllm":
if args.async_engine:
@ -368,23 +400,34 @@ def main(args: argparse.Namespace):
AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing,
args.disable_detokenize,
))
)
)
else:
elapsed_time, request_outputs = run_vllm(
requests, args.n, EngineArgs.from_cli_args(args),
args.disable_detokenize)
requests,
args.n,
EngineArgs.from_cli_args(args),
args.disable_detokenize,
)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.hf_max_batch_size, args.trust_remote_code,
args.disable_detokenize)
elapsed_time = run_hf(
requests,
args.model,
tokenizer,
args.n,
args.hf_max_batch_size,
args.trust_remote_code,
args.disable_detokenize,
)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
elapsed_time = run_mii(
requests, args.model, args.tensor_parallel_size, args.output_len
)
elif args.backend == "vllm-chat":
elapsed_time, request_outputs = run_vllm_chat(
requests, args.n, EngineArgs.from_cli_args(args),
args.disable_detokenize)
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
)
else:
raise ValueError(f"Unknown backend: {args.backend}")
@ -396,28 +439,31 @@ def main(args: argparse.Namespace):
for ro in request_outputs:
if not isinstance(ro, RequestOutput):
continue
total_prompt_tokens += len(
ro.prompt_token_ids) if ro.prompt_token_ids else 0
total_output_tokens += sum(
len(o.token_ids) for o in ro.outputs if o)
total_prompt_tokens += (
len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
)
total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
total_num_tokens = total_prompt_tokens + total_output_tokens
else:
total_num_tokens = sum(r.prompt_len + r.expected_output_len
for r in requests)
total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
total_output_tokens = sum(r.expected_output_len for r in requests)
total_prompt_tokens = total_num_tokens - total_output_tokens
if is_multi_modal and args.backend != "vllm-chat":
print("\033[91mWARNING\033[0m: Multi-modal request with "
print(
"\033[91mWARNING\033[0m: Multi-modal request with "
f"{args.backend} backend detected. The "
"following metrics are not accurate because image tokens are not"
" counted. See vllm-project/vllm/issues/9778 for details.")
" counted. See vllm-project/vllm/issues/9778 for details."
)
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
# vllm-chat backend counts the image tokens now
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
)
print(f"Total num prompt tokens: {total_prompt_tokens}")
print(f"Total num output tokens: {total_output_tokens}")
@ -445,7 +491,8 @@ def validate_args(args):
warnings.warn(
"The '--dataset' argument will be deprecated in the next release. "
"Please use '--dataset-name' and '--dataset-path' instead.",
stacklevel=2)
stacklevel=2,
)
args.dataset_path = args.dataset
if not getattr(args, "tokenizer", None):
@ -458,9 +505,8 @@ def validate_args(args):
# === Dataset Configuration ===
if not args.dataset and not args.dataset_path:
print(
"When dataset path is not set, it will default to random dataset")
args.dataset_name = 'random'
print("When dataset path is not set, it will default to random dataset")
args.dataset_name = "random"
if args.input_len is None:
raise ValueError("input_len must be provided for a random dataset")
@ -469,40 +515,54 @@ def validate_args(args):
# when dataset_name is 'hf'
if args.dataset_name != "hf" and (
getattr(args, "hf_subset", None) is not None
or getattr(args, "hf_split", None) is not None):
warnings.warn("--hf-subset and --hf-split will be ignored \
or getattr(args, "hf_split", None) is not None
):
warnings.warn(
"--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.",
stacklevel=2)
stacklevel=2,
)
elif args.dataset_name == "hf":
if args.dataset_path in (
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| ConversationDataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
| ConversationDataset.SUPPORTED_DATASET_PATHS
):
assert args.backend == "vllm-chat", (
f"{args.dataset_path} needs to use vllm-chat as the backend."
) # noqa: E501
elif args.dataset_path in (
InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS
):
assert args.backend == "vllm", (
f"{args.dataset_path} needs to use vllm as the backend."
) # noqa: E501
else:
raise ValueError(
f"{args.dataset_path} is not supported by hf dataset.")
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != 'random' and args.random_range_ratio is not None:
warnings.warn("--random-range-ratio will be ignored since \
if args.dataset_name != "random" and args.random_range_ratio is not None:
warnings.warn(
"--random-range-ratio will be ignored since \
--dataset-name is not 'random'.",
stacklevel=2)
stacklevel=2,
)
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
# set.
if args.dataset_name not in {"random", "sonnet", None
} and args.prefix_len is not None:
warnings.warn("--prefix-len will be ignored since --dataset-name\
if (
args.dataset_name not in {"random", "sonnet", None}
and args.prefix_len is not None
):
warnings.warn(
"--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.",
stacklevel=2)
stacklevel=2,
)
# === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm":
raise ValueError(
"LoRA benchmarking is only supported for vLLM backend")
raise ValueError("LoRA benchmarking is only supported for vLLM backend")
if getattr(args, "enable_lora", False) and args.lora_path is None:
raise ValueError("LoRA path must be provided when enable_lora is True")
@ -512,8 +572,10 @@ def validate_args(args):
if args.backend != "hf" and args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
if args.backend in {"hf", "mii"} and getattr(args, "quantization",
None) is not None:
if (
args.backend in {"hf", "mii"}
and getattr(args, "quantization", None) is not None
):
raise ValueError("Quantization is only for vLLM backend.")
if args.backend == "mii" and args.dtype != "auto":
@ -521,29 +583,32 @@ def validate_args(args):
if args.backend == "mii" and args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.backend == "mii" and args.tokenizer != args.model:
raise ValueError(
"Tokenizer must be the same as the model for MII backend.")
raise ValueError("Tokenizer must be the same as the model for MII backend.")
# --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1:
raise ValueError(
"Data parallel is not supported in offline benchmark, \
please use benchmark serving instead")
please use benchmark serving instead"
)
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
parser.add_argument(
"--backend",
type=str,
choices=["vllm", "hf", "mii", "vllm-chat"],
default="vllm")
default="vllm",
)
parser.add_argument(
"--dataset-name",
type=str,
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
help="Name of the dataset to benchmark on.",
default="sharegpt")
default="sharegpt",
)
parser.add_argument(
"--dataset",
type=str,
@ -551,57 +616,70 @@ if __name__ == "__main__":
help="Path to the ShareGPT dataset, will be deprecated in\
the next release. The dataset is expected to "
"be a json in form of list[dict[..., conversations: "
"list[dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset")
parser.add_argument("--input-len",
"list[dict[..., value: <prompt_or_response>]]]]",
)
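An illustrative record of the ShareGPT-style JSON described above. Only the "conversations" list and the "value" fields are implied by the help text; other keys (such as "from") are shown purely as an example:

# One entry of the expected list[dict[..., conversations: ...]] structure.
example_sharegpt_record = {
    "conversations": [
        {"from": "human", "value": "<prompt>"},
        {"from": "gpt", "value": "<response>"},
    ]
}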
parser.add_argument(
"--dataset-path", type=str, default=None, help="Path to the dataset"
)
parser.add_argument(
"--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
help="Input prompt length for each request",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--hf-max-batch-size",
"output length from the dataset.",
)
parser.add_argument(
"--n", type=int, default=1, help="Number of generated sequences per prompt."
)
parser.add_argument(
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
)
parser.add_argument(
"--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
help="Maximum batch size for HF backend.",
)
parser.add_argument(
'--output-json',
"--output-json",
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument("--async-engine",
action='store_true',
help="Path to save the throughput results in JSON format.",
)
parser.add_argument(
"--async-engine",
action="store_true",
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
help="Use vLLM async engine rather than LLM class.",
)
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
default=False,
help="Disable decoupled async engine frontend.")
help="Disable decoupled async engine frontend.",
)
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=("Do not detokenize the response (i.e. do not include "
"detokenization time in the measurement)"))
help=(
"Do not detokenize the response (i.e. do not include "
"detokenization time in the measurement)"
),
)
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to the LoRA adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")
"a relative path, or a Hugging Face model identifier.",
)
parser.add_argument(
"--prefix-len",
type=int,
@ -615,7 +693,8 @@ if __name__ == "__main__":
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
"controls how much of the input is fixed lines versus "
"random lines, but the total input length remains approximately "
"input_len tokens.")
"input_len tokens.",
)
# random dataset
parser.add_argument(
"--random-range-ratio",
@ -629,14 +708,12 @@ if __name__ == "__main__":
)
# hf dataset
parser.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
parser.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
parser.add_argument(
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
)
parser.add_argument(
"--hf-split", type=str, default=None, help="Split of the HF dataset."
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

View File

@ -7,9 +7,9 @@ import os
from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
metrics: dict[str, list],
extra_info: dict[str, Any]) -> list:
def convert_to_pytorch_benchmark_format(
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
) -> list:
"""
Save the benchmark results in the format used by PyTorch OSS benchmark with
one metric per record
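As a rough illustration of what "one metric per record" means for callers (values are invented, and the record layout beyond the fields visible in this diff is assumed, not specified here):

import argparse

from benchmark_utils import convert_to_pytorch_benchmark_format

args = argparse.Namespace(model="facebook/opt-125m", tensor_parallel_size=1)
records = convert_to_pytorch_benchmark_format(
    args=args,
    metrics={"requests_per_second": [12.3], "tokens_per_second": [4567.8]},
    extra_info={"elapsed_time": 81.3, "num_requests": 1000},
)
# One record per metric name -> two records here, each carrying the CLI args
# under record["benchmark"]["extra_info"]["args"].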
@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
},
}
tp = record["benchmark"]["extra_info"]["args"].get(
"tensor_parallel_size")
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"][
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
extra_info["tensor_parallel_size"]
)
records.append(record)
@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder):
def clear_inf(self, o: Any):
if isinstance(o, dict):
return {k: self.clear_inf(v) for k, v in o.items()}

View File

@ -23,8 +23,9 @@ DEFAULT_TP_SIZES = [1]
# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
def bench_fn(
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1
globals = {
@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
).blocked_autorange(min_run_time=min_run_time)
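A self-contained sketch of the measurement pattern bench_fn wraps here, i.e. a torch.utils.benchmark.Timer over fn(*args, **kwargs) run for at least min_run_time seconds; the exact stmt and globals used by the script may differ:

import torch
import torch.utils.benchmark as TBenchmark

def measure(label, sub_label, description, fn, *args, **kwargs):
    # Assumed equivalent of bench_fn: time fn(*args, **kwargs) for >= 1 s.
    return TBenchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals={"fn": fn, "args": args, "kwargs": kwargs},
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=1)

a = torch.randn(256, 512)
b = torch.randn(512, 128)
print(measure("gemm", "MKN=(256x512x128)", "pytorch_fp32_matmul", torch.mm, a, b))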
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
def bench_int8(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.int8
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers = []
# pytorch impl - bfloat16
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16)))
bench_fn(
label,
sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16),
)
)
# pytorch impl - float16
timers.append(
bench_fn(label, sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
bench_fn(
label,
sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.float16),
b.to(dtype=torch.float16),
)
)
# cutlass impl
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass sparse impl
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16))
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass sparse with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16, bias))
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
return timers
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
def bench_fp8(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
k)
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
@ -124,13 +180,20 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# pytorch impl w. bf16
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))
bench_fn(
label,
sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"),
)
)
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(label,
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
@ -138,11 +201,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16))
out_dtype=torch.bfloat16,
)
)
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(label,
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
@ -151,11 +217,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True))
use_fast_accum=True,
)
)
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(label,
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
@ -163,11 +232,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16))
out_dtype=torch.float16,
)
)
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(label,
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
@ -176,45 +248,97 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True))
use_fast_accum=True,
)
)
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16))
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.float16))
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
)
)
# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label,
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16, bias))
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label,
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.float16, bias.to(dtype=torch.float16)))
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
bias.to(dtype=torch.float16),
)
)
return timers
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
def bench(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print()
def run(dtype: torch.dtype,
MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
def run(
dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})")
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
print_timers(timers)
results.extend(timers)
@ -241,10 +365,12 @@ def run(dtype: torch.dtype,
# output makers
def make_output(data: Iterable[TMeasurement],
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]],
base_description: str,
timestamp=None):
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs)
@ -319,7 +444,7 @@ def run_model_bench(args):
pkl.dump(all_data, f)
if __name__ == '__main__':
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
@ -344,12 +469,15 @@ Benchmark Cutlass GEMM.
Output:
- a .pkl file that is a list of raw torch.utils.benchmark Measurement objects for the pytorch and cutlass implementations of the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype",
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']")
help="Available options are ['int8', 'fp8']",
)
subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench")
@ -368,19 +496,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models",
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument(
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
)
model_parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()

View File

@ -10,8 +10,9 @@ import vllm._custom_ops as ops
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
dtype=torch.float8_e4m3fn
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
def make_rand_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
if dtype == torch.int8:
return to_int8(a), to_int8(b)
@ -49,9 +51,7 @@ def prune_to_2_4(tensor):
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1,
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
@ -62,10 +62,11 @@ def prune_to_2_4(tensor):
return pruned.reshape(original_shape)
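A tiny worked example of the 2:4 pattern this helper produces, assuming (as the masking above suggests) that the two largest-magnitude values in every group of four are kept and the rest zeroed:

import torch

x = torch.tensor([[1.0, -5.0, 3.0, 2.0, 4.0, 0.5, -0.1, 2.5]])
groups = x.reshape(-1, 4)
_, idx = groups.abs().topk(2, dim=1)
mask = torch.zeros_like(groups).scatter_(1, idx, 1.0)
print((groups * mask).reshape(x.shape))
# -> [[0, -5, 3, 0, 4, 0, 0, 2.5]] (two non-zeros per group of four)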
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
b = prune_to_2_4(b.t()).t()
@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
return b_compressed, e, a, b
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
m: int, n: int, k: int) -> \
tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
def make_n_rand_sparse_tensors(
num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
ABs = []
for _ in range(num_tensors):
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)

View File

@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul)
w8a8_block_fp8_matmul,
)
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1]
# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
def bench_fn(
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1
globals = {
@ -50,39 +52,42 @@ def bench_int8(
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels."""
assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
),
"cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16
),
"cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16, bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj
),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
),
"pytorch_fp16_fp16_fp16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
"cutlass_i8_i8_bf16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
"cutlass_i8_i8_bf16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
bias),
"cutlass_i8_i8_bf16_scaled_mm_azp":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, None, bias),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, azp),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, azp, bias),
}
timers = []
@ -102,67 +107,59 @@ def bench_fp8(
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128),
device="cuda",
dtype=torch.float32)
block_scale_b = torch.rand((k // 128, n // 128),
device="cuda",
dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32)
block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32)
block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t()
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
print(m, k, n)
bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
),
"pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16
),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
),
"pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16
),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
),
"cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16
),
"cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16
),
"cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16, bias
),
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
),
"pytorch_fp16_fp16_fp16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
"pytorch_fp8_fp8_fp16_scaled_mm":
lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
lambda: torch._scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.float16,
use_fast_accum=True),
"pytorch_fp8_fp8_bf16_scaled_mm":
lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
lambda: torch._scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True),
"cutlass_fp8_fp8_bf16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
"cutlass_fp8_fp8_fp16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
bias),
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)),
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
block_scale_b.t(), (128, 128)),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
block_scale_b_K_major, torch.float16),
}
timers = []
@ -175,13 +172,15 @@ def bench_fp8(
return timers
def bench(dtype: torch.dtype,
def bench(
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn:
@ -195,27 +194,33 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print()
def run(dtype: torch.dtype,
def run(
dtype: torch.dtype,
MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype,
timers = bench(
dtype,
m,
k,
n,
f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})",
bench_kernels=bench_kernels)
bench_kernels=bench_kernels,
)
print_timers(timers)
results.extend(timers)
return results
def make_output(data: Iterable[TMeasurement],
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]],
base_description: str,
timestamp=None):
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
@ -226,8 +231,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
make_output(data, MKNs, f"square_bench-{args.dtype}")
@ -285,7 +289,7 @@ def run_model_bench(args):
pkl.dump(all_data, f)
if __name__ == '__main__':
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
@ -310,19 +314,21 @@ Benchmark Cutlass GEMM.
Output:
- a .pkl file that is a list of raw torch.utils.benchmark Measurement objects for the pytorch and cutlass implementations of the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype",
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']")
help="Available options are ['int8', 'fp8']",
)
parser.add_argument(
"--kernels",
nargs="+",
type=str,
default=None,
help=
"Exact names of the kernels to benchmark. If not set, runs all kernels."
help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
)
subparsers = parser.add_subparsers(dest="cmd")
@ -343,19 +349,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models",
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument(
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
)
model_parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()

View File

@ -12,39 +12,37 @@ app = Quart(__name__)
async def forward_request(url, data):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
async with session.post(url=url, json=data,
headers=headers) as response:
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
# if response.headers.get('Transfer-Encoding') == 'chunked':
if True:
async for chunk_bytes in response.content.iter_chunked(
1024):
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
@app.route('/v1/completions', methods=['POST'])
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# set max_tokens = 1 so the forwarded request only performs prefill
prefill_request['max_tokens'] = 1
prefill_request["max_tokens"] = 1
# finish prefill
async for _ in forward_request('http://localhost:8100/v1/completions',
prefill_request):
async for _ in forward_request(
"http://localhost:8100/v1/completions", prefill_request
):
continue
# return decode
generator = forward_request('http://localhost:8200/v1/completions',
original_request_data)
generator = forward_request(
"http://localhost:8200/v1/completions", original_request_data
)
response = await make_response(generator)
response.timeout = None
@ -53,11 +51,12 @@ async def handle_request():
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == '__main__':
if __name__ == "__main__":
app.run(port=8000)
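For context, a hypothetical client request against this proxy: it listens on port 8000, replays the request with max_tokens=1 against the prefill server on :8100, then streams the real completion from the decode server on :8200. The model name below is a placeholder and must match whatever the two vLLM servers are serving:

import json
import urllib.request

payload = {
    "model": "<your-served-model>",  # placeholder
    "prompt": "San Francisco is a",
    "max_tokens": 32,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())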

View File

@ -8,7 +8,6 @@ from aiohttp import web
class RoundRobinProxy:
def __init__(self, target_ports):
self.target_ports = target_ports
self.port_cycle = itertools.cycle(self.target_ports)
@ -27,8 +26,9 @@ class RoundRobinProxy:
data=request.content,
) as response:
# Start sending the response
resp = web.StreamResponse(status=response.status,
headers=response.headers)
resp = web.StreamResponse(
status=response.status, headers=response.headers
)
await resp.prepare(request)
# Stream the response content
@ -45,11 +45,11 @@ class RoundRobinProxy:
async def main():
proxy = RoundRobinProxy([8100, 8200])
app = web.Application()
app.router.add_route('*', '/{path:.*}', proxy.handle_request)
app.router.add_route("*", "/{path:.*}", proxy.handle_request)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, 'localhost', 8000)
site = web.TCPSite(runner, "localhost", 8000)
await site.start()
print("Proxy server started on http://localhost:8000")
@ -58,5 +58,5 @@ async def main():
await asyncio.Event().wait()
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())
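The round-robin behaviour comes entirely from itertools.cycle over the target ports; a quick illustration:

import itertools

ports = itertools.cycle([8100, 8200])
print([next(ports) for _ in range(5)])  # [8100, 8200, 8100, 8200, 8100]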

View File

@ -6,43 +6,41 @@ import matplotlib.pyplot as plt
import pandas as pd
if __name__ == "__main__":
data = []
for name in ['disagg_prefill', 'chunked_prefill']:
for name in ["disagg_prefill", "chunked_prefill"]:
for qps in [2, 4, 6, 8]:
with open(f"results/{name}-qps-{qps}.json") as f:
x = json.load(f)
x['name'] = name
x['qps'] = qps
x["name"] = name
x["qps"] = qps
data.append(x)
df = pd.DataFrame.from_dict(data)
dis_df = df[df['name'] == 'disagg_prefill']
chu_df = df[df['name'] == 'chunked_prefill']
dis_df = df[df["name"] == "disagg_prefill"]
chu_df = df[df["name"] == "chunked_prefill"]
plt.style.use('bmh')
plt.rcParams['font.size'] = 20
plt.style.use("bmh")
plt.rcParams["font.size"] = 20
for key in [
'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms',
'median_itl_ms', 'p99_itl_ms'
"mean_ttft_ms",
"median_ttft_ms",
"p99_ttft_ms",
"mean_itl_ms",
"median_itl_ms",
"p99_itl_ms",
]:
fig, ax = plt.subplots(figsize=(11, 7))
plt.plot(dis_df['qps'],
dis_df[key],
label='disagg_prefill',
marker='o',
linewidth=4)
plt.plot(chu_df['qps'],
chu_df[key],
label='chunked_prefill',
marker='o',
linewidth=4)
plt.plot(
dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
)
plt.plot(
chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
)
ax.legend()
ax.set_xlabel('QPS')
ax.set_xlabel("QPS")
ax.set_ylabel(key)
ax.set_ylim(bottom=0)
fig.savefig(f'results/{key}.png')
fig.savefig(f"results/{key}.png")
plt.close(fig)

View File

@ -24,10 +24,12 @@ class bench_params_t:
dtype: torch.dtype
def description(self):
return (f'N {self.num_tokens} '
f'x D {self.hidden_size} '
f'x R {self.add_residual} '
f'x DT {self.dtype}')
return (
f"N {self.num_tokens} "
f"x D {self.hidden_size} "
f"x R {self.add_residual} "
f"x DT {self.dtype}"
)
def get_bench_params() -> list[bench_params_t]:
@ -38,15 +40,19 @@ def get_bench_params() -> list[bench_params_t]:
DTYPES = [torch.bfloat16, torch.float]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
bench_params = list(map(lambda x: \
bench_params_t(x[0], x[1], x[2], x[3]), combinations))
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
)
return bench_params
# Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
def unfused_int8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
quant_dtype: torch.dtype,
):
# Norm
torch_out = None
if residual is None:
@ -58,9 +64,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
torch_out, _, _ = ops.scaled_int8_quant(torch_out)
def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
def unfused_fp8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
quant_dtype: torch.dtype,
):
# Norm
torch_out = None
if residual is None:
@ -76,19 +85,24 @@ def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
out, _ = ops.rms_norm_dynamic_per_token_quant(x,
rms_norm_layer.weight,
1e-6,
quant_dtype,
residual=residual)
quant_dtype: torch.dtype,
):
out, _ = ops.rms_norm_dynamic_per_token_quant(
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
)
# Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
quant_dtype: torch.dtype, label: str, sub_label: str,
fn: Callable, description: str) -> TMeasurement:
def bench_fn(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor,
quant_dtype: torch.dtype,
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1
globals = {
@ -106,43 +120,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) \
-> Iterable[TMeasurement]:
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
# Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
scale = 1 / params.hidden_size
x = torch.randn(params.num_tokens,
params.hidden_size,
dtype=params.dtype,
device='cuda') * scale
residual = (torch.randn_like(x) * scale).to(device='cuda') \
if params.add_residual else None
x = (
torch.randn(
params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
)
* scale
)
residual = (
(torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
)
timers = []
# unfused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label,
unfused_int8_impl, "unfused_int8_impl"))
bench_fn(
layer,
x,
residual,
torch.int8,
label,
sub_label,
unfused_int8_impl,
"unfused_int8_impl",
)
)
# unfused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
unfused_fp8_impl, "unfused_fp8_impl"))
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# fused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
"fused_int8_impl"))
bench_fn(
layer,
x,
residual,
torch.int8,
label,
sub_label,
fused_impl,
"fused_int8_impl",
)
)
# fused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
fused_impl, "fused_fp8_impl"))
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
label,
sub_label,
fused_impl,
"fused_fp8_impl",
)
)
print_timers(timers)
@ -157,13 +209,12 @@ def print_timers(timers: Iterable[TMeasurement]):
def main():
torch.set_default_device('cuda')
torch.set_default_device("cuda")
bench_params = get_bench_params()
timers = []
for bp in tqdm(bench_params):
timers.extend(
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers)
# pickle all the results
@ -172,5 +223,5 @@ def main():
pkl.dump(timers, f)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -9,32 +9,39 @@ import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
dequantize_weight,
generic_dequantize_gemm,
get_int_dtype,
optimized_dequantize_gemm,
)
from vllm.utils import FlexibleArgumentParser
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def torch_mult(
input: torch.Tensor, # [..., in_features]
# [..., in_features]
input: torch.Tensor,
weights: torch.Tensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
) -> torch.Tensor:
output = F.linear(input, weights)
return output
def dequant_out_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
# [..., in_features]
input: torch.Tensor,
# [num_out_groups, num_in_groups, num_codebooks]
codes: torch.IntTensor,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
@ -46,40 +53,42 @@ def dequant_out_scale(
flattened_output *= b_scales
return flattened_output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_weight_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
# [..., in_features]
input: torch.Tensor,
# [num_out_groups, num_in_groups, num_codebooks]
codes: torch.IntTensor,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_no_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
# [..., in_features]
input: torch.Tensor,
# [num_out_groups, num_in_groups, num_codebooks]
codes: torch.IntTensor,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
return F.linear(input, weights, bias)
@ -89,23 +98,26 @@ def dequant_no_scale(
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
n = int(parts.sum().item())
device = torch.device('cuda:0')
device = torch.device("cuda:0")
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(-code_range,
codes = torch.randint(
-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)
device=device,
)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
codebooks = torch.randn(
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)
device=device,
)
count = 0
for index in range(16):
@ -138,24 +150,25 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
def main():
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
# Add arguments
parser.add_argument("--nbooks",
type=int,
default=1,
help="Number of codebooks (default: 1)")
parser.add_argument("--bits",
parser.add_argument(
"--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
)
parser.add_argument(
"--bits",
type=int,
default=16,
help="Number of bits per code element (default: 16)")
help="Number of bits per code element (default: 16)",
)
parser.add_argument(
"--test",
type=bool,
default=False,
help="Run the decompression/dequant tester rather than benchmarking "
"(default: False)")
"(default: False)",
)
# Parse the arguments
args = parser.parse_args()
@ -165,7 +178,7 @@ def main():
bits = args.bits
if args.test:
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
return
# Otherwise, benchmark.
@ -184,31 +197,54 @@ def main():
with open(filename, "w") as f:
sys.stdout = f
print('m | k | n | n parts', end='')
print("m | k | n | n parts", end="")
for method in methods:
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
print('')
print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
print("")
# These are reasonable prefill sizes.
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
(4096, (11008, 11008)), (11008, (4096, )))
ksandpartions = (
(4096, (4096, 4096, 4096)),
(4096, (4096,)),
(4096, (11008, 11008)),
(11008, (4096,)),
)
# reasonable ranges for m.
for m in [
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
128, 256, 512, 1024, 1536, 2048, 3072, 4096
1,
2,
4,
8,
10,
12,
14,
16,
24,
32,
48,
52,
56,
64,
96,
112,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]:
print(f'{m}', file=sys.__stdout__)
print(f"{m}", file=sys.__stdout__)
for ksp in ksandpartions:
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
methods)
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
sys.stdout = sys.__stdout__
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
methods):
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
# I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials = 1
num_trials = 1
@ -229,7 +265,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
)
n = parts.sum().item()
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
print(f"{m} | {k} | {n} | {parts.tolist()}", end="")
for method in methods:
best_time_us = 1e20
@ -249,32 +285,36 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
if kernel_dur_us < best_time_us:
best_time_us = kernel_dur_us
print(f' | {kernel_dur_us:.0f}', end='')
print(f" | {kernel_dur_us:.0f}", end="")
print('')
print("")
def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
nbooks: int, bits: int, method) -> float:
def run_timing(
num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
) -> float:
n = int(parts.sum().item())
device = torch.device('cuda:0')
device = torch.device("cuda:0")
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(-code_range,
codes = torch.randint(
-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)
device=device,
)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
codebooks = torch.randn(
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)
device=device,
)
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

View File

@ -3,16 +3,21 @@
# Licensed under the MIT License.
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
MINIMUM_BITBLAS_VERSION)
MINIMUM_BITBLAS_VERSION,
)
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError("bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
)
except ImportError as e:
bitblas_import_exception = e
raise ValueError("Trying to use the bitblas backend, but could not import"
raise ValueError(
"Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
@ -23,7 +28,8 @@ from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
from vllm.utils import FlexibleArgumentParser
parser = FlexibleArgumentParser(
description="Benchmark BitBLAS int4 on a specific target.")
description="Benchmark BitBLAS int4 on a specific target."
)
# Add arguments to the parser
parser.add_argument(
@ -32,10 +38,9 @@ parser.add_argument(
default=auto_detect_nvidia_target(),
help="Specify the target device for benchmarking.",
)
parser.add_argument("--group_size",
type=int,
default=None,
help="Group size for grouped quantization.")
parser.add_argument(
"--group_size", type=int, default=None, help="Group size for grouped quantization."
)
parser.add_argument(
"--A_dtype",
type=str,
@ -82,17 +87,17 @@ parser.add_argument(
choices=["nt", "nn"],
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
parser.add_argument("--with_bias",
action="store_true",
help="Include bias in the benchmark.")
parser.add_argument(
"--with_bias", action="store_true", help="Include bias in the benchmark."
)
parser.add_argument(
"--with_scaling",
action="store_true",
help="Include scaling factor in the quantization.",
)
parser.add_argument("--with_zeros",
action="store_true",
help="Include zeros in the quantization.")
parser.add_argument(
"--with_zeros", action="store_true", help="Include zeros in the quantization."
)
parser.add_argument(
"--zeros_mode",
type=str,
@ -170,8 +175,7 @@ shapes = [
]
# Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
for shape in shapes]
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]
benchmark_sets = []
benchmark_sets.extend(test_shapes)
@ -206,12 +210,12 @@ for config_key, values in benchmark_results.items():
func_name = args_split[0]
input_args_str = "-".join(args_split[1:])
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
col_widths[1] = max(col_widths[1],
len(input_args_str) + 2,
len(headers[1]) + 2)
col_widths[2] = max(col_widths[2],
col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
col_widths[2] = max(
col_widths[2],
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
len(headers[2]) + 2)
len(headers[2]) + 2,
)
# break only if you want to measure widths from a single example;
# otherwise, let it loop over all items.
@ -232,5 +236,6 @@ for config_key, values in benchmark_results.items():
f"{values['BitBLAS_top20_latency']:.3f} ms",
]
row_str = "".join(
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
)
print(row_str)

View File

@ -5,6 +5,7 @@ kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights (tensor scaled to fp8)
and 16-bit activations.
"""
import nvtx
import torch
import torch.utils.benchmark as benchmark
@ -12,8 +13,7 @@ import torch.utils.benchmark as benchmark
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types
from vllm.utils import FlexibleArgumentParser
@ -38,19 +38,27 @@ FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
dtype=torch.float8_e4m3fn
)
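As a point of reference for the to_fp8 helper above, a small standalone sketch (assuming a PyTorch build with float8_e4m3fn support; the input tensor is illustrative) that applies the same clamp-round-cast sequence and inspects the round-trip error:

import torch

finfo = torch.finfo(torch.float8_e4m3fn)
x = torch.randn(4, 8) * 10
# Clamp to the representable e4m3 range, round (as the helper above does), then cast down.
x_fp8 = torch.round(x.clamp(min=finfo.min, max=finfo.max)).to(torch.float8_e4m3fn)
# Cast back up to float32 to see the quantization error.
print(x_fp8.dtype, (x - x_fp8.to(torch.float32)).abs().max())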
def bench_run(results: list[benchmark.Measurement], model: str,
num_experts: int, topk: int, per_act_token: bool,
per_out_ch: bool, mkn: tuple[int, int, int]):
def bench_run(
results: list[benchmark.Measurement],
model: str,
num_experts: int,
topk: int,
per_act_token: bool,
per_out_ch: bool,
mkn: tuple[int, int, int],
):
label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
sub_label = (
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
"MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
mkn))
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
model, num_experts, topk, per_act_token, per_out_ch, mkn
)
)
print(f"Testing: {sub_label}")
@ -64,18 +72,12 @@ def bench_run(results: list[benchmark.Measurement], model: str,
_, a_fp8_scale = ops.scaled_fp8_quant(a)
w1_fp8q = torch.empty((num_experts, 2 * n, k),
device=device,
dtype=torch.float8_e4m3fn)
w2_fp8q = torch.empty((num_experts, k, n),
device=device,
dtype=torch.float8_e4m3fn)
w1_fp8scale = torch.empty((num_experts, 1, 1),
device=device,
dtype=torch.float32)
w2_fp8scale = torch.empty((num_experts, 1, 1),
device=device,
dtype=torch.float32)
w1_fp8q = torch.empty(
(num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
)
w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
for expert in range(num_experts):
w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
@ -91,26 +93,24 @@ def bench_run(results: list[benchmark.Measurement], model: str,
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
quant_blocksize = 16
w1_blockscale = torch.empty((num_experts, 2 * n, k // quant_blocksize),
w1_blockscale = torch.empty(
(num_experts, 2 * n, k // quant_blocksize),
device=device,
dtype=torch.float8_e4m3fn)
w2_blockscale = torch.empty((num_experts, k, n // quant_blocksize),
device=device,
dtype=torch.float8_e4m3fn)
dtype=torch.float8_e4m3fn,
)
w2_blockscale = torch.empty(
(num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
)
# n_b_scales = 2 * n if per_out_ch else 1
# k_b_scales = k if per_out_ch else 1
w1_fp4 = torch.empty((num_experts, 2 * n, k // 2),
device=device,
dtype=torch.uint8)
w2_fp4 = torch.empty((num_experts, k, n // 2),
device=device,
dtype=torch.uint8)
w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)
w1_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32)
w2_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32)
a1_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32)
a2_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32)
w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
for expert in range(num_experts):
w1_e = w1[expert]
@ -121,17 +121,27 @@ def bench_run(results: list[benchmark.Measurement], model: str,
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1_e, w1_gs[expert])
w1_e, w1_gs[expert]
)
w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2_e, w2_gs[expert])
w2_e, w2_gs[expert]
)
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a_fp8_scale: torch.Tensor, num_repeats: int):
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_fp8_scale: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats):
fused_experts(a,
fused_experts(
a,
w1,
w2,
topk_weights,
@ -139,18 +149,32 @@ def bench_run(results: list[benchmark.Measurement], model: str,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale)
a1_scale=a_fp8_scale,
)
def run_cutlass_moe_fp4(a: torch.Tensor, w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor, w1_gs: torch.Tensor,
w2_gs: torch.Tensor, a1_gs: torch.Tensor,
a2_gs: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, m: int, n: int, k: int,
e: int, device: torch.device, num_repeats: int):
def run_cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor,
w1_gs: torch.Tensor,
w2_gs: torch.Tensor,
a1_gs: torch.Tensor,
a2_gs: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
num_repeats: int,
):
for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4(a=a,
cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_fp4=w1_fp4,
@ -165,19 +189,32 @@ def bench_run(results: list[benchmark.Measurement], model: str,
n=n,
k=k,
e=num_experts,
device=device)
device=device,
)
def run_cutlass_from_graph(
a: torch.Tensor, a1_gscale: torch.Tensor, w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor, w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor, w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor, w2_alphas: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int,
k: int, e: int, device: torch.device):
a: torch.Tensor,
a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe_fp4(a=a,
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
@ -192,17 +229,24 @@ def bench_run(results: list[benchmark.Measurement], model: str,
n=n,
k=k,
e=num_experts,
device=device)
device=device,
)
def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, w1_scale: torch.Tensor,
def run_triton_from_graph(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_fp8_scale: torch.Tensor):
a_fp8_scale: torch.Tensor,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return fused_experts(a,
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return fused_experts(
a,
w1,
w2,
topk_weights,
@ -210,7 +254,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale)
a1_scale=a_fp8_scale,
)
def replay_graph(graph, num_repeats):
for _ in range(num_repeats):
@ -220,7 +265,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
run_cutlass_from_graph(a=a,
run_cutlass_from_graph(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
@ -235,15 +281,23 @@ def bench_run(results: list[benchmark.Measurement], model: str,
n=n,
k=k,
e=num_experts,
device=device)
device=device,
)
torch.cuda.synchronize()
triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(a, w1_fp8q_notransp, w2_fp8q_notransp,
topk_weights, topk_ids, w1_fp8scale, w2_fp8scale,
a_fp8_scale)
run_triton_from_graph(
a,
w1_fp8q_notransp,
w2_fp8q_notransp,
topk_weights,
topk_ids,
w1_fp8scale,
w2_fp8scale,
a_fp8_scale,
)
torch.cuda.synchronize()
min_run_time = 5
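The stream/graph capture and replay above follows the standard torch.cuda.CUDAGraph recipe. A condensed sketch with a stand-in workload (a plain matmul rather than the MoE kernels), assuming a CUDA device is available:

import torch

x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

# Warm up on a side stream before capture so lazy library init happens outside the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x @ w
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = x @ w  # kernels are recorded into the graph during capture

for _ in range(10):
    g.replay()  # re-launch the captured kernels against the same static buffers
torch.cuda.synchronize()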
@ -290,18 +344,27 @@ def bench_run(results: list[benchmark.Measurement], model: str,
}
# Warmup
run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights,
topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_warmup)
run_triton_moe(
a,
w1_fp8q_notransp,
w2_fp8q_notransp,
topk_weights,
topk_ids,
w1_fp8scale,
w2_fp8scale,
a_fp8_scale,
num_warmup,
)
results.append(
benchmark.Timer(
stmt=
"run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501
stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
replay_graph(triton_graph, num_warmup)
@ -313,23 +376,40 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label,
sub_label=sub_label,
description="triton_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_gs,
w2_gs, a1_gs, a2_gs, topk_weights, topk_ids, m, n, k,
num_experts, device, num_warmup)
run_cutlass_moe_fp4(
a,
w1_fp4,
w2_fp4,
w1_blockscale,
w2_blockscale,
w1_gs,
w2_gs,
a1_gs,
a2_gs,
topk_weights,
topk_ids,
m,
n,
k,
num_experts,
device,
num_warmup,
)
results.append(
benchmark.Timer(
stmt=
"run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501
stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="cutlass_moe_fp4",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
replay_graph(cutlass_graph, num_warmup)
@ -341,7 +421,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label,
sub_label=sub_label,
description="cutlass_moe_fp4_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
def main(args):
@ -369,8 +450,15 @@ def main(args):
for per_out_ch in PER_OUT_CH_OPTS:
for size_m in args.batch_sizes:
mkn = (size_m, size_k, size_n)
bench_run(results, model, num_experts, topk,
per_act_token, per_out_ch, mkn)
bench_run(
results,
model,
num_experts,
topk,
per_act_token,
per_out_ch,
mkn,
)
compare = benchmark.Compare(results)
compare.print()
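Each entry appended to results above is a torch.utils.benchmark.Measurement; label titles the printed table, sub_label names a row, and description names a column. A minimal sketch of the same Timer / blocked_autorange / Compare pattern on a trivial matmul (shapes and labels here are illustrative):

import torch
import torch.utils.benchmark as benchmark

results = []
for n in (256, 512):
    a = torch.randn(n, n)
    b = torch.randn(n, n)
    results.append(
        benchmark.Timer(
            stmt="torch.mm(a, b)",
            globals={"torch": torch, "a": a, "b": b},
            label="toy matmul",      # table title
            sub_label=f"n={n}",      # row
            description="torch_mm",  # column
        ).blocked_autorange(min_run_time=1)
    )

benchmark.Compare(results).print()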
@ -378,8 +466,8 @@ def main(args):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark NVFP4 CUTLASS MOE across specified "
"models/shapes/batches")
description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
)
parser.add_argument(
"--models",
nargs="+",
@ -387,21 +475,14 @@ if __name__ == "__main__":
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES_MOE.keys(),
)
parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-act-token",
nargs="+",
type=int,
default=[])
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
args = parser.parse_args()

View File

@ -6,14 +6,18 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
from vllm.model_executor.layers.fused_moe.fused_moe import (
cutlass_moe_fp8,
fused_experts,
fused_topk)
fused_topk,
)
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = [
"nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
"ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
"nm-testing/Mixtral-8x7B-Instruct-v0.1",
"nm-testing/deepseekv2-lite",
"ibm-granite/granite-3.0-1b-a400m",
"ibm-granite/granite-3.0-3b-a800m",
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
@ -24,19 +28,27 @@ PER_OUT_CH_OPTS = [False]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
dtype=torch.float8_e4m3fn
)
def bench_run(results: list[benchmark.Measurement], model: str,
num_experts: int, topk: int, per_act_token: bool,
per_out_ch: bool, mkn: tuple[int, int, int]):
def bench_run(
results: list[benchmark.Measurement],
model: str,
num_experts: int,
topk: int,
per_act_token: bool,
per_out_ch: bool,
mkn: tuple[int, int, int],
):
label = "Quant Matmul"
sub_label = (
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
"MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
mkn))
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
model, num_experts, topk, per_act_token, per_out_ch, mkn
)
)
print(f"Testing: {sub_label}")
@ -50,35 +62,17 @@ def bench_run(results: list[benchmark.Measurement], model: str,
_, a_scale = ops.scaled_fp8_quant(a)
w1_q = torch.empty((num_experts, 2 * n, k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2_q = torch.empty((num_experts, k, n),
device="cuda",
dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1),
device="cuda",
dtype=torch.float32)
w1_q = torch.empty(
(num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
)
w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
ab_strides1 = torch.full((num_experts, ),
k,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_experts, ),
2 * n,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_experts, ),
n,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_experts, ),
k,
device="cuda",
dtype=torch.int64)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
@ -91,14 +85,23 @@ def bench_run(results: list[benchmark.Measurement], model: str,
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, renormalize=False)
a, score, topk, renormalize=False
)
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a_scale: torch.Tensor, num_repeats: int):
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_scale: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats):
fused_experts(a,
fused_experts(
a,
w1,
w2,
topk_weights,
@ -106,17 +109,27 @@ def bench_run(results: list[benchmark.Measurement], model: str,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)
a1_scale=a_scale,
)
def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
num_repeats: int):
def run_cutlass_moe(
a: torch.Tensor,
a_scale: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats):
cutlass_moe_fp8(a,
cutlass_moe_fp8(
a,
w1,
w2,
w1_scale,
@ -127,18 +140,28 @@ def bench_run(results: list[benchmark.Measurement], model: str,
c_strides1,
ab_strides2,
c_strides2,
a1_scale=a_scale)
a1_scale=a_scale,
)
def run_cutlass_from_graph(
a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
a: torch.Tensor,
a_scale: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe_fp8(a,
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp8(
a,
w1_q,
w2_q,
w1_scale,
@ -149,16 +172,24 @@ def bench_run(results: list[benchmark.Measurement], model: str,
c_strides1,
ab_strides2,
c_strides2,
a1_scale=a_scale)
a1_scale=a_scale,
)
def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, a_scale: torch.Tensor):
def run_triton_from_graph(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_scale: torch.Tensor,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return fused_experts(a,
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return fused_experts(
a,
w1,
w2,
topk_weights,
@ -166,7 +197,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)
a1_scale=a_scale,
)
def replay_graph(graph, num_repeats):
for _ in range(num_repeats):
@ -176,16 +208,35 @@ def bench_run(results: list[benchmark.Measurement], model: str,
cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
topk_weights, topk_ids, ab_strides1, c_strides1,
ab_strides2, c_strides2)
run_cutlass_from_graph(
a,
a_scale,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
)
torch.cuda.synchronize()
triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
topk_ids, w1_scale, w2_scale, a_scale)
run_triton_from_graph(
a,
w1_q_notransp,
w2_q_notransp,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
a_scale,
)
torch.cuda.synchronize()
min_run_time = 5
@ -225,18 +276,27 @@ def bench_run(results: list[benchmark.Measurement], model: str,
}
# Warmup
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
w1_scale, w2_scale, a_scale, num_warmup)
run_triton_moe(
a,
w1_q_notransp,
w2_q_notransp,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
a_scale,
num_warmup,
)
results.append(
benchmark.Timer(
stmt=
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
replay_graph(triton_graph, num_warmup)
@ -248,22 +308,35 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label,
sub_label=sub_label,
description="triton_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
num_warmup)
run_cutlass_moe(
a,
a_scale,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
num_warmup,
)
results.append(
benchmark.Timer(
stmt=
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="grouped_gemm_moe",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
replay_graph(cutlass_graph, num_warmup)
@ -275,7 +348,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label,
sub_label=sub_label,
description="grouped_gemm_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
def main(args):
@ -303,8 +377,15 @@ def main(args):
for per_out_ch in PER_OUT_CH_OPTS:
for size_m in DEFAULT_BATCH_SIZES:
mkn = (size_m, size_k, size_n)
bench_run(results, model, num_experts, topk,
per_act_token, per_out_ch, mkn)
bench_run(
results,
model,
num_experts,
topk,
per_act_token,
per_out_ch,
mkn,
)
compare = benchmark.Compare(results)
compare.print()
@ -312,7 +393,8 @@ def main(args):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
description="Benchmark Marlin across specified models/shapes/batches"
)
parser.add_argument(
"--models",
nargs="+",
@ -320,21 +402,14 @@ if __name__ == "__main__":
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES_MOE.keys(),
)
parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-act-token",
nargs="+",
type=int,
default=[])
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
args = parser.parse_args()

View File

@ -10,14 +10,16 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode()
def main(num_tokens: int,
def main(
num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
seed: int = 0,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
num_iters: int = 100,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device("cuda")
@ -56,33 +58,35 @@ def main(num_tokens: int,
print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__':
parser = FlexibleArgumentParser(
description="Benchmark the layernorm kernel.")
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--add-residual", action="store_true")
parser.add_argument("--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters",
parser.add_argument(
"--num-iters",
type=int,
default=100,
help="Number of benchmark iterations. "
"If --profile is set, this number is ignored")
"If --profile is set, this number is ignored",
)
args = parser.parse_args()
print(args)
main(num_tokens=args.num_tokens,
main(
num_tokens=args.num_tokens,
hidden_size=args.hidden_size,
add_residual=args.add_residual,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
seed=args.seed,
do_profile=args.profile,
num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters)
num_iters=args.num_iters,
)

View File

@ -20,18 +20,36 @@ from weight_shapes import WEIGHT_SHAPES
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
lora_shrink)
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
_LORA_B_PTR_DICT)
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024,
2048, 3072, 4096, 5120, 6144, 7168, 8192
1,
16,
32,
64,
128,
192,
256,
320,
384,
448,
512,
640,
768,
896,
1024,
2048,
3072,
4096,
5120,
6144,
7168,
8192,
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
@ -52,12 +70,9 @@ def dtype_to_str(dtype: torch.dtype):
raise ValueError(f"Unsupported dtype {dtype}")
def make_rand_lora_weight_tensor(k: int,
n: int,
num_loras: int,
dtype: torch.dtype,
device: str = "cuda") -> torch.Tensor:
def make_rand_lora_weight_tensor(
k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda"
) -> torch.Tensor:
# LoRA weights column major
return torch.rand((num_loras, n, k), dtype=dtype).to(device)
@ -78,18 +93,15 @@ def make_rand_tensors(
A = torch.rand(a_shape, dtype=a_dtype).to(device)
# LoRA weights column major
Bs = [
torch.rand(b_shape, dtype=b_dtype).to(device)
for _ in range(num_slices)
]
Bs = [torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices)]
C = torch.zeros(c_shape, dtype=c_dtype).to(device)
return A, Bs, C
def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
sort_by_lora_id: bool,
device: str) -> torch.Tensor:
def make_prompt_lora_mapping(
num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str
) -> torch.Tensor:
"""
All prompts are mapped to a LoRA ID in range [0, num_active_loras),
where 0 refers to the first lora, 1 refers to the second lora, and so on.
@ -97,9 +109,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
assert num_active_loras > 0
if not sort_by_lora_id:
return torch.randint(0,
num_active_loras, (num_prompts, ),
dtype=torch.long)
return torch.randint(0, num_active_loras, (num_prompts,), dtype=torch.long)
# Divide LoRAs equally and in order.
part_size = num_prompts // num_active_loras
@ -110,14 +120,18 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
while len(prompt_lora_mapping) < num_prompts:
prompt_lora_mapping.extend([lora_id] * part_size)
lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id
return torch.tensor(prompt_lora_mapping[:num_prompts],
dtype=torch.long,
device=device)
return torch.tensor(
prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device
)
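A tiny worked example of the sorted mapping this function produces: with 6 prompts and 3 active LoRAs, prompts are assigned in equal, ordered chunks. The sketch below reproduces just that slice of the logic (the counts are illustrative):

num_prompts, num_active_loras = 6, 3
part_size = num_prompts // num_active_loras  # 2 prompts per LoRA

prompt_lora_mapping, lora_id = [], 0
while len(prompt_lora_mapping) < num_prompts:
    prompt_lora_mapping.extend([lora_id] * part_size)
    lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id

print(prompt_lora_mapping[:num_prompts])  # [0, 0, 1, 1, 2, 2]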
def make_token_lora_mapping(num_tokens: int, num_prompts: int,
def make_token_lora_mapping(
num_tokens: int,
num_prompts: int,
prompt_lora_mapping: torch.Tensor,
seq_len_tensor: torch.Tensor, device: str):
seq_len_tensor: torch.Tensor,
device: str,
):
"""
Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor
"""
@ -136,11 +150,15 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int,
return torch.tensor(token_lora_mapping, dtype=torch.long, device=device)
def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
def ref_group_gemm(
ref_out: torch.Tensor,
input: torch.Tensor,
lora_weights: list[torch.Tensor],
seq_lens_cpu: torch.Tensor,
prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
add_inputs: Optional[bool]):
prompt_lora_mapping_cpu: torch.Tensor,
scaling: float,
add_inputs: Optional[bool],
):
"""
Torch group gemm reference implementation to test correctness of
benchmarking operations.
@ -149,7 +167,7 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
out_list = []
current_offset = 0
for lora_index, b_length in zip(range(batches), seq_lens_cpu):
x = input[current_offset:b_length + current_offset, :]
x = input[current_offset : b_length + current_offset, :]
current_offset += b_length
w = lora_weights[prompt_lora_mapping_cpu[lora_index]]
result = torch.nn.functional.linear(x, w)
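The reference loop above walks the token dimension prompt by prompt, picking each prompt's LoRA weight and applying it with F.linear. A self-contained sketch of that grouped-GEMM idea with made-up shapes:

import torch
import torch.nn.functional as F

seq_lens = [3, 2]             # tokens per prompt
prompt_lora_mapping = [1, 0]  # prompt index -> LoRA index
lora_weights = [torch.randn(8, 16), torch.randn(8, 16)]  # one [n, k] weight per LoRA

x_all = torch.randn(sum(seq_lens), 16)  # [num_tokens, k]
outs, offset = [], 0
for prompt_idx, length in enumerate(seq_lens):
    x = x_all[offset : offset + length, :]
    w = lora_weights[prompt_lora_mapping[prompt_idx]]
    outs.append(F.linear(x, w))  # [length, n]
    offset += length

print(torch.cat(outs).shape)     # torch.Size([5, 8])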
@ -168,6 +186,7 @@ class OpType(Enum):
"""
LoRA Ops to benchmark and their properties.
"""
LORA_SHRINK = auto()
LORA_EXPAND = auto()
@ -188,8 +207,9 @@ class OpType(Enum):
def num_slices(self) -> list[int]:
return [1, 2, 3]
def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
lora_rank: int) -> tuple[int, int, int]:
def mkn(
self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
) -> tuple[int, int, int]:
num_tokens = batch_size * seq_length
if self.is_shrink_fn():
m = num_tokens
@ -215,9 +235,14 @@ class OpType(Enum):
return torch.float32, op_dtype, op_dtype
def matmul_shapes(
self, batch_size: int, seq_length: int, hidden_size: int,
lora_rank: int, num_loras: int,
num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]:
self,
batch_size: int,
seq_length: int,
hidden_size: int,
lora_rank: int,
num_loras: int,
num_slices: int,
) -> tuple[tuple[int], tuple[int], tuple[int]]:
"""
Given num_slices, return the shapes of the A, B, and C matrices
in A x B = C, for the op_type
@ -241,9 +266,13 @@ class OpType(Enum):
raise ValueError(f"Unrecognized optype {self}")
def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
def run_ref_group_gemm(
self,
output: torch.Tensor,
input: torch.Tensor,
lora_weights: list[torch.Tensor],
**kwargs) -> Callable:
**kwargs,
) -> Callable:
"""Each benchmark operation expects the input, lora_weights and outputs
in a slightly different format. Refer to self.matmul_shapes().
run_ref_group_gemm accounts for those differences in executing a
@ -253,19 +282,22 @@ class OpType(Enum):
num_slices = len(lora_weights)
if self in [OpType.LORA_SHRINK]:
for slice_idx in range(num_slices):
ref_group_gemm(ref_out=output[slice_idx, :],
ref_group_gemm(
ref_out=output[slice_idx, :],
input=input,
lora_weights=lora_weights[slice_idx],
**kwargs)
**kwargs,
)
elif self in [OpType.LORA_EXPAND]:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
ref_group_gemm(
ref_out=output[:, slice_offset:slice_offset + hidden_size],
ref_out=output[:, slice_offset : slice_offset + hidden_size],
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
**kwargs,
)
else:
raise ValueError(f"Unrecognized optype {self}")
@ -275,6 +307,7 @@ class BenchmarkContext:
"""
LoRA benchmark context
"""
batch_size: int
hidden_size: int
num_loras: int
@ -299,17 +332,18 @@ class BenchmarkContext:
return f"lora-{self.dtype}"
def bench_sublabel(self, op_type: OpType) -> str:
m, k, n = op_type.mkn(self.batch_size, self.seq_length,
self.hidden_size, self.lora_rank)
m, k, n = op_type.mkn(
self.batch_size, self.seq_length, self.hidden_size, self.lora_rank
)
desc = {
'bs': self.batch_size,
'sl': self.seq_length,
'm': m,
'k': k,
'n': n,
'num_loras': self.num_loras,
'sort_by_lora': self.sort_by_lora_id,
'num_slices': self.num_slices,
"bs": self.batch_size,
"sl": self.seq_length,
"m": m,
"k": k,
"n": n,
"num_loras": self.num_loras,
"sort_by_lora": self.sort_by_lora_id,
"num_slices": self.num_slices,
}
return json.dumps(desc)
@ -319,6 +353,7 @@ class BenchmarkTensors:
"""
Input/Output tensors used for benchmarks
"""
# matmul tensors
input: torch.Tensor
lora_weights_lst: list[torch.Tensor]
@ -330,23 +365,29 @@ class BenchmarkTensors:
prompt_lora_mapping: torch.Tensor
def io_types(self) -> str:
return (f"{dtype_to_str(self.input.dtype)}x"
return (
f"{dtype_to_str(self.input.dtype)}x"
f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
f"{dtype_to_str(self.output.dtype)}")
f"{dtype_to_str(self.output.dtype)}"
)
@staticmethod
def make(ctx: BenchmarkContext,
op_type: OpType,
device: str = "cuda") -> "BenchmarkTensors":
def make(
ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
) -> "BenchmarkTensors":
# Make input / output matmul tensors.
a_shape, b_shape, c_shape = op_type.matmul_shapes(
ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank,
ctx.num_loras, ctx.num_slices)
ctx.batch_size,
ctx.seq_length,
ctx.hidden_size,
ctx.lora_rank,
ctx.num_loras,
ctx.num_slices,
)
a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
input_tensor, lora_weights, output_tensor = \
make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type, c_type,
num_slices = ctx.num_slices)
input_tensor, lora_weights, output_tensor = make_rand_tensors(
a_shape, b_shape, c_shape, a_type, b_type, c_type, num_slices=ctx.num_slices
)
# Make metadata tensors.
# Keep the metadata tensors in the CPU for further processing if needed.
@ -356,27 +397,38 @@ class BenchmarkTensors:
# Make metadata tensors involved in correctness testing.
# Prepare seq lens tensor
seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
(ctx.batch_size, ))
seq_len_tensor = torch.randint(
ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,)
)
assert total_tokens == seq_len_tensor.sum()
# Prepare prompt lora indices tensor
prompt_lora_indices_tensor = make_prompt_lora_mapping(
ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
)
# Make LoRAKernelMeta
token_lora_indices_tensor = make_token_lora_mapping(
total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
seq_len_tensor, "cpu")
total_tokens,
ctx.batch_size,
prompt_lora_indices_tensor,
seq_len_tensor,
"cpu",
)
lora_kernel_meta = LoRAKernelMeta.make(
max_loras=ctx.num_loras,
max_num_tokens=token_lora_indices_tensor.size(0),
device="cpu")
lora_kernel_meta.prepare_tensors(
token_lora_mapping=token_lora_indices_tensor)
device="cpu",
)
lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)
return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
lora_kernel_meta, seq_len_tensor,
prompt_lora_indices_tensor)
return BenchmarkTensors(
input_tensor,
lora_weights,
output_tensor,
lora_kernel_meta,
seq_len_tensor,
prompt_lora_indices_tensor,
)
def sanity_check(self) -> None:
"""
@ -386,7 +438,7 @@ class BenchmarkTensors:
# check metadata tensors
assert torch.sum(self.seq_lens) == num_tokens
num_seqs = self.seq_lens.shape[0]
#assert self.seq_start_loc.shape[0] == num_seqs
# assert self.seq_start_loc.shape[0] == num_seqs
assert self.prompt_lora_mapping.shape[0] == num_seqs
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
@ -430,8 +482,11 @@ class BenchmarkTensors:
_, num_tokens, _, num_slices = self.metadata()
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
i_shape, lw_shape, o_shape = (
self.input.shape,
self.lora_weights_lst[0].shape,
self.output.shape,
)
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
@ -445,16 +500,17 @@ class BenchmarkTensors:
assert o_shape == (num_slices, num_tokens, lora_rank)
return {
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
'lora_ids': self.lora_kernel_meta.active_lora_ids,
'scaling': 1.0,
"inputs": self.input,
"lora_a_weights": self.lora_weights_lst,
"output_tensor": self.output,
"token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
"token_indices_sorted_by_lora_ids": (
self.lora_kernel_meta.token_indices_sorted_by_lora_ids
),
"num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
"lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
"lora_ids": self.lora_kernel_meta.active_lora_ids,
"scaling": 1.0,
}
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
@ -464,8 +520,11 @@ class BenchmarkTensors:
_, num_tokens, _, num_slices = self.metadata()
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
i_shape, lw_shape, o_shape = (
self.input.shape,
self.lora_weights_lst[0].shape,
self.output.shape,
)
# Expected input shape : [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
@ -480,22 +539,23 @@ class BenchmarkTensors:
assert o_shape == (num_tokens, hidden_size * num_slices)
return {
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
'lora_ids': self.lora_kernel_meta.active_lora_ids,
'offset_start': 0,
'add_inputs': add_inputs,
"inputs": self.input,
"lora_b_weights": self.lora_weights_lst,
"output_tensor": self.output,
"token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
"token_indices_sorted_by_lora_ids": (
self.lora_kernel_meta.token_indices_sorted_by_lora_ids
),
"num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
"lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
"lora_ids": self.lora_kernel_meta.active_lora_ids,
"offset_start": 0,
"add_inputs": add_inputs,
}
def bench_fn_kwargs(self,
op_type: OpType,
add_inputs: Optional[bool] = None) -> dict[str, Any]:
def bench_fn_kwargs(
self, op_type: OpType, add_inputs: Optional[bool] = None
) -> dict[str, Any]:
if op_type.is_shrink_fn():
assert add_inputs is None
else:
@ -507,8 +567,9 @@ class BenchmarkTensors:
return self.as_lora_expand_kwargs(add_inputs)
raise ValueError(f"Unrecognized optype {self}")
def test_correctness(self, op_type: OpType,
expand_fn_add_inputs: Optional[bool]) -> bool:
def test_correctness(
self, op_type: OpType, expand_fn_add_inputs: Optional[bool]
) -> bool:
"""
Test correctness of op_type implementation against a grouped gemm
reference implementation.
@ -518,8 +579,7 @@ class BenchmarkTensors:
ref_output = self.output.clone()
self.output.zero_()
op_type.bench_fn()(
**self.bench_fn_kwargs(op_type, expand_fn_add_inputs))
op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs))
op_type.run_ref_group_gemm(
ref_output,
@ -528,7 +588,8 @@ class BenchmarkTensors:
seq_lens_cpu=seq_lens_cpu,
prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
scaling=1.0,
add_inputs=expand_fn_add_inputs)
add_inputs=expand_fn_add_inputs,
)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
@ -539,13 +600,14 @@ class BenchmarkTensors:
return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)
def bench_optype(ctx: BenchmarkContext,
def bench_optype(
ctx: BenchmarkContext,
arg_pool_size: int,
op_type: OpType,
cuda_graph_nops: Optional[int] = None,
expand_fn_add_inputs: Optional[bool] = None,
test_correctness: bool = False) -> TMeasurement:
test_correctness: bool = False,
) -> TMeasurement:
assert arg_pool_size >= 1
if op_type.is_shrink_fn():
assert expand_fn_add_inputs is None
@ -553,17 +615,17 @@ def bench_optype(ctx: BenchmarkContext,
assert expand_fn_add_inputs is not None
# BenchmarkContext -> BenchmarkTensors
bench_tensors : list[BenchmarkTensors] = \
[BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)]
bench_tensors: list[BenchmarkTensors] = [
BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
]
for bt in bench_tensors:
bt.sanity_check()
# Test correctness of our implementation.
if test_correctness:
assert all([
bt.test_correctness(op_type, expand_fn_add_inputs)
for bt in bench_tensors
])
assert all(
[bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors]
)
# BenchmarkTensors -> dict (kwargs)
kwargs_list = [
@ -585,26 +647,33 @@ def bench_optype(ctx: BenchmarkContext,
for k, v in _kwargs.items():
kwargs[k].values.append(v)
describe_args = (f"add_inputs={expand_fn_add_inputs}"
if expand_fn_add_inputs is not None else "")
description = (
f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})")
describe_args = (
f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
)
description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"
cuda_graph_params = None
if cuda_graph_nops:
cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
timer = None
with Bench(cuda_graph_params,
ctx.bench_label(), ctx.bench_sublabel(op_type), description,
op_type.bench_fn(), **kwargs) as bench:
with Bench(
cuda_graph_params,
ctx.bench_label(),
ctx.bench_sublabel(op_type),
description,
op_type.bench_fn(),
**kwargs,
) as bench:
timer = bench.run()
return timer
def bench_torch_mm(ctx: BenchmarkContext,
def bench_torch_mm(
ctx: BenchmarkContext,
arg_pool_size: int,
op_type: OpType,
cuda_graph_nops: Optional[int] = None) -> TMeasurement:
cuda_graph_nops: Optional[int] = None,
) -> TMeasurement:
"""
Benchmark basic torch.mm as a roofline.
@ -614,11 +683,13 @@ def bench_torch_mm(ctx: BenchmarkContext,
input op_type is used in determining the m, k, n dimensions for the matmul.
"""
batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size,
batch_size, hidden_size, lora_rank, seq_length, dtype = (
ctx.batch_size,
ctx.hidden_size,
ctx.lora_rank,
ctx.seq_length,
ctx.dtype)
ctx.dtype,
)
m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank)
# For a fairer comparison.
@ -632,18 +703,24 @@ def bench_torch_mm(ctx: BenchmarkContext,
Cs.append(torch.rand((m, n), dtype=dtype).to("cuda"))
# Make torch.mm kwargs
mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)}
mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)}
description = (
f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
f"x{dtype_to_str(dtype)}"
f"=>{dtype_to_str(dtype)})")
f"=>{dtype_to_str(dtype)})"
)
cuda_graph_params = None
if cuda_graph_nops:
cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
with Bench(cuda_graph_params, ctx.bench_label(),
ctx.bench_sublabel(op_type), description, torch.mm,
**mm_kwargs) as bench:
with Bench(
cuda_graph_params,
ctx.bench_label(),
ctx.bench_sublabel(op_type),
description,
torch.mm,
**mm_kwargs,
) as bench:
return bench.run()
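Stripped of the ArgPool and Bench wrappers, the roofline above is a single dense matmul of the op's m, k, n shape. A minimal equivalent, assuming a CUDA device (the shapes are illustrative):

import torch

m, k, n = 1024, 4096, 16  # e.g. num_tokens x hidden_size projected down to lora_rank
a = torch.rand((m, k), dtype=torch.float16, device="cuda")
b = torch.rand((k, n), dtype=torch.float16, device="cuda")
c = torch.zeros((m, n), dtype=torch.float16, device="cuda")

torch.mm(a, b, out=c)  # the single-LoRA roofline matmul being timed
torch.cuda.synchronize()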
@ -660,8 +737,7 @@ def use_cuda_graph_recommendation() -> str:
"""
def print_timers(timers: list[TMeasurement],
args: Optional[argparse.Namespace] = None):
def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None):
compare = TBenchmark.Compare(timers)
compare.print()
@ -670,22 +746,23 @@ def print_timers(timers: list[TMeasurement],
f"Note : The timings reported above is for {args.cuda_graph_nops} "
"consecutive invocations of the benchmarking functions. "
f"Please divide by {args.cuda_graph_nops} for single invocation "
"timings.")
"timings."
)
print("Note on Comparison with torch.mm : The torch.mm numbers are "
print(
"Note on Comparison with torch.mm : The torch.mm numbers are "
"benchmark numbers of a simple matmul emulating the single lora "
"case. It is provided as a roofline for comparing our LoRA Kernel "
"implementations. It is expected that the LoRA kernels will be "
"slower than torch.mm in cases where num_loras is big. But for "
"small num_loras the goal should be to match the torch.mm numbers.")
"small num_loras the goal should be to match the torch.mm numbers."
)
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
if args.cuda_graph_nops is not None:
assert args.cuda_graph_nops > 0
print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA "
"Graph")
print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph")
else:
print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")
@ -697,21 +774,30 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
for bench_op in bench_ops:
for num_slices in bench_op.num_slices():
_ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
num_slices)
num_slices
)
# Benchmark torch.mm as a roofline
seq_len_timers.append(
bench_torch_mm(_ctx, args.arg_pool_size, bench_op,
args.cuda_graph_nops))
bench_torch_mm(
_ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops
)
)
# Benchmark bench_op
expand_fn_add_inputs = [
None
] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
expand_fn_add_inputs = (
[None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
)
for add_input_arg in expand_fn_add_inputs:
seq_len_timers.append(
bench_optype(_ctx, args.arg_pool_size, bench_op,
args.cuda_graph_nops, add_input_arg,
args.test_correctness))
bench_optype(
_ctx,
args.arg_pool_size,
bench_op,
args.cuda_graph_nops,
add_input_arg,
args.test_correctness,
)
)
print_timers(seq_len_timers)
timers.extend(seq_len_timers)
@ -733,13 +819,17 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
pickle.dump(timers, f)
def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
args: argparse.Namespace) -> list[BenchmarkContext]:
def as_benchmark_contexts(
hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]:
ctxs: list[BenchmarkContext] = []
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
args.sort_by_lora_id):
args.batch_sizes,
list(hidden_sizes),
lora_ranks,
args.num_loras,
args.sort_by_lora_id,
):
ctxs.append(
BenchmarkContext(
batch_size=batch_size,
@ -747,13 +837,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
lora_rank=lora_rank,
num_loras=num_loras,
num_active_loras=args.num_active_loras
if args.num_active_loras else num_loras,
if args.num_active_loras
else num_loras,
# To be filled based on the OpType to benchmark
seq_length=None,
sort_by_lora_id=sort_by_lora_id,
dtype=args.dtype,
# To be filled based on the OpType to benchmark
num_slices=None))
num_slices=None,
)
)
return ctxs
@ -761,13 +854,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
def run_list_bench(args: argparse.Namespace):
print(args)
print("List bench :\n"
print(
"List bench :\n"
f" Hidden Sizes {args.hidden_sizes}"
f" LoRA Ranks {args.lora_ranks}")
f" LoRA Ranks {args.lora_ranks}"
)
# Get all benchmarking contexts
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)
hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args
)
run(args, bench_contexts)
@ -776,19 +872,22 @@ def run_range_bench(args: argparse.Namespace):
print(args)
hidden_sizes = list(
range(args.hidden_sizes_start, args.hidden_sizes_end + 1,
args.hidden_sizes_increment))
range(
args.hidden_sizes_start,
args.hidden_sizes_end + 1,
args.hidden_sizes_increment,
)
)
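# Illustrative values: with --hidden-sizes-start 1024 --hidden-sizes-end 4096
# --hidden-sizes-increment 1024 (as in the range_bench example in the CLI help),
# the "+ 1" makes the end inclusive and this yields [1024, 2048, 3072, 4096].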
lora_ranks = list(
range(args.lora_ranks_start, args.lora_ranks_end + 1,
args.lora_ranks_increment))
range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment)
)
print("Range bench :\n"
f" Hidden Sizes {hidden_sizes}"
f" LoRA Ranks {lora_ranks}")
print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}")
# Get all benchmarking contexts
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)
hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args
)
run(args, bench_contexts)
@ -806,21 +905,19 @@ def run_model_bench(args: argparse.Namespace):
# Get all hidden sizes
hidden_sizes: set[int] = set()
for model_name, tp_size in product(args.models, args.tp_sizes):
hidden_sizes = hidden_sizes.union(
hidden_sizes_from_model(model_name, tp_size))
hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size))
print("Model bench :\n"
f" Hidden Sizes {hidden_sizes}"
f" LoRA Ranks {args.lora_ranks}")
print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}")
# Get all benchmarking contexts
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)
hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args
)
run(args, bench_contexts)
if __name__ == '__main__':
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "torch.float16":
@ -830,14 +927,15 @@ if __name__ == '__main__':
raise ValueError("unsupported dtype")
def get_bool(s: str) -> bool:
return s.lower() in ['true', '1']
return s.lower() in ["true", "1"]
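# For example, get_bool("True") and get_bool("1") return True while
# get_bool("0") returns False; this lets flags such as --sort-by-lora-id
# accept 0/1 style values on the command line.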
def add_common_command_args(p: argparse.ArgumentParser):
p.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['torch.float16', 'torch.bfloat16']")
help="Available options are ['torch.float16', 'torch.bfloat16']",
)
p.add_argument(
"--arg-pool-size",
@ -845,56 +943,66 @@ if __name__ == '__main__':
default=32,
help="Run profiles with a pool of input/output/meta tensors instead "
"of simply reusing the same tensors for all runs. A bigger arg-pool "
"mitigates hardware caching effects during benchmarking.")
"mitigates hardware caching effects during benchmarking.",
)
p.add_argument(
"--cuda-graph-nops",
type=int,
help=("when set, profiling is done using a CUDA graph, "
help=(
"when set, profiling is done using a CUDA graph, "
"with the given number of operations in a graph. "
"Note that the measurement returned is the time "
"taken for N consecutive executions of the benchmarking "
"functions, where N is the value of this argument."))
p.add_argument("--num-loras",
nargs="+",
type=int,
default=DEFAULT_NUM_LORAS)
p.add_argument("--num-active-loras",
"functions, where N is the value of this argument."
),
)
p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS)
p.add_argument(
"--num-active-loras",
type=int,
default=None,
help="Active LoRAs. When None, all LoRAs are active")
p.add_argument("--sort-by-lora-id",
nargs="+",
type=get_bool,
default=DEFAULT_SORT_BY_LORA_IDS)
p.add_argument("--op-types",
nargs="+",
type=OpType.from_str,
default=list(OpType))
p.add_argument('--seq-lengths',
nargs="+",
type=int,
default=DEFAULT_SEQ_LENGTHS)
p.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
p.add_argument("--expand-fn-add-inputs",
nargs="+",
type=get_bool,
default=DEFAULT_EXPAND_FN_ADD_INPUTS)
help="Active LoRAs. When None, all LoRAs are active",
)
p.add_argument(
'-o',
'--output-directory',
"--sort-by-lora-id",
nargs="+",
type=get_bool,
default=DEFAULT_SORT_BY_LORA_IDS,
)
p.add_argument(
"--op-types", nargs="+", type=OpType.from_str, default=list(OpType)
)
p.add_argument(
"--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS
)
p.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
p.add_argument(
"--expand-fn-add-inputs",
nargs="+",
type=get_bool,
default=DEFAULT_EXPAND_FN_ADD_INPUTS,
)
p.add_argument(
"-o",
"--output-directory",
type=str,
help=("Output directory to store the list of benchmarking "
"TMeasurement objects as a pickle file"))
help=(
"Output directory to store the list of benchmarking "
"TMeasurement objects as a pickle file"
),
)
p.add_argument(
"--test-correctness",
action='store_true',
help=("When enabled, the benchmarking functions are tested "
"for correctness before the actual benchmarking"))
action="store_true",
help=(
"When enabled, the benchmarking functions are tested "
"for correctness before the actual benchmarking"
),
)
parser = FlexibleArgumentParser(
description=f"""
@ -910,50 +1018,45 @@ Benchmark LoRA kernels:
range_bench example:
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)
formatter_class=argparse.RawTextHelpFormatter,
)
subparsers = parser.add_subparsers(dest="cmd", required=True)
list_parser = subparsers.add_parser("list_bench")
list_parser.add_argument("--hidden-sizes",
nargs="+",
type=int,
default=DEFAULT_HIDDEN_SIZES)
list_parser.add_argument("--lora-ranks",
nargs="+",
type=int,
default=DEFAULT_LORA_RANKS)
list_parser.add_argument(
"--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES
)
list_parser.add_argument(
"--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
)
add_common_command_args(list_parser)
list_parser.set_defaults(func=run_list_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
range_parser.add_argument("--hidden-sizes-increment",
type=int,
required=True)
range_parser.add_argument("--hidden-sizes-increment", type=int, required=True)
range_parser.add_argument("--lora-ranks-start", type=int, required=True)
range_parser.add_argument("--lora-ranks-end", type=int, required=True)
range_parser.add_argument("--lora-ranks-increment",
type=int,
required=True)
range_parser.add_argument("--lora-ranks-increment", type=int, required=True)
add_common_command_args(range_parser)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models",
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--lora-ranks",
nargs="+",
type=int,
default=DEFAULT_LORA_RANKS)
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument(
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
)
model_parser.add_argument(
"--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
)
add_common_command_args(model_parser)
model_parser.set_defaults(func=run_model_bench)
View File
@ -20,12 +20,18 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
marlin_zero_points)
GPTQ_MARLIN_MAX_PARALLEL,
GPTQ_MARLIN_MIN_THREAD_N,
marlin_permute_scales,
marlin_zero_points,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
MarlinWorkspace,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
pack_rows,
quantize_weights,
)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
@ -82,12 +88,14 @@ def rand_data(shape, dtype=torch.float16, scale=1):
return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
def quantize_and_pack(atype: torch.dtype,
def quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(
@ -96,21 +104,24 @@ def quantize_and_pack(atype: torch.dtype,
group_size=group_size,
zero_points=zero_points,
# to match how the kernel applies zps
ref_zero_points_after_scales=True)
ref_zero_points_after_scales=True,
)
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
return w_ref, w_q, w_s, w_zp
def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
group_size: Optional[int]) -> list[BenchmarkTensors]:
def create_bench_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
) -> list[BenchmarkTensors]:
m, n, k = shape
# we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb
num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
(k * n * types.weight_type.size_bits))
num_weights = math.ceil(
2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
)
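# Worked example with illustrative shapes: for k = n = 4096 and 4-bit weights,
# the target of 2 * 50 * 1024**2 * 8 = 838,860,800 bits divided by
# k * n * 4 = 67,108,864 bits per weight matrix gives ~12.5, so 13 weight
# copies are allocated to keep L2 from caching them across runs.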
a = rand_data((m, k), types.act_type, scale=5)
@ -124,8 +135,13 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
w = w.to(torch.float16)
w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
types.group_zero_type is not None)
a.dtype,
w,
types.weight_type,
types.group_scale_type,
group_size,
types.group_zero_type is not None,
)
if not a.dtype.is_floating_point:
aiinfo = torch.iinfo(a.dtype)
@ -133,13 +149,20 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
w_ref = w_ref.to(torch.float32)
w_ch_s = None if types.channel_scale_type is None else\
rand_data((n,), types.channel_scale_type)
w_tok_s = None if types.token_scale_type is None else\
rand_data((m,), types.token_scale_type)
w_ch_s = (
None
if types.channel_scale_type is None
else rand_data((n,), types.channel_scale_type)
)
w_tok_s = (
None
if types.token_scale_type is None
else rand_data((m,), types.token_scale_type)
)
benchmark_tensors.append(
BenchmarkTensors(w_ref=w_ref,
BenchmarkTensors(
w_ref=w_ref,
a=a,
w_q=w_q_packed,
wtype=types.weight_type,
@ -147,7 +170,9 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
w_g_zp=w_zp,
group_size=group_size,
w_ch_s=w_ch_s,
w_tok_s=w_tok_s))
w_tok_s=w_tok_s,
)
)
return benchmark_tensors
@ -170,38 +195,44 @@ def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
return lambda: ops.cutlass_scaled_mm(
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
)
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
device = bt.a.device
workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
workspace = MarlinWorkspace(
bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
)
if bt.w_g_zp is None:
w_zp = torch.empty(0, dtype=torch.int, device=device)
else:
w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
bt.w_ref.shape[1], bt.wtype.size_bits)
w_zp = marlin_zero_points(
bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
)
if bt.group_size is None:
w_s = torch.tensor([], device="cuda", dtype=torch.half)
else:
w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
bt.w_ref.shape[1], bt.group_size)
w_s = marlin_permute_scales(
bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
)
sort_indices = torch.empty(0, dtype=torch.int, device=device)
g_idx = torch.empty(0, dtype=torch.int, device=device)
w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
bt.w_ref.shape[1], bt.wtype.size_bits)
w_q = ops.gptq_marlin_repack(
bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
)
if bt.a.dtype.is_floating_point:
assert bt.w_ch_s is None
assert bt.w_tok_s is None
assert bt.group_size is not None
fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
fn = lambda: ops.gptq_marlin_gemm(
a=bt.a,
b_q_weight=w_q,
b_scales=w_s,
b_zeros=w_zp,
@ -213,7 +244,8 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
size_n=bt.w_ref.shape[1],
size_k=bt.w_ref.shape[0],
is_k_full=True,
is_zp_float=False)
is_zp_float=False,
)
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8
@ -221,18 +253,15 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
if bt.w_ch_s is not None:
s_ch = bt.w_ch_s.to(torch.float32)
else:
s_ch = torch.ones(bt.w_ref.shape[1],
dtype=torch.float32,
device=device)
s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
if bt.w_tok_s is not None:
s_tok = bt.w_tok_s.to(torch.float32)
else:
s_tok = torch.ones(bt.a.shape[0],
dtype=torch.float32,
device=device)
s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
fn = lambda: ops.marlin_qqq_gemm(
a=bt.a,
b_q_weight=w_q,
s_group=w_s,
s_tok=s_tok,
@ -240,17 +269,19 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
workspace=workspace.scratch,
size_m=bt.a.shape[0],
size_n=bt.w_ref.shape[1],
size_k=bt.w_ref.shape[0])
size_k=bt.w_ref.shape[0],
)
return fn
def machete_create_bench_fn(bt: BenchmarkTensors,
out_type=torch.dtype,
schedule=None) -> Callable:
def machete_create_bench_fn(
bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
) -> Callable:
w_q = bt.w_q.t().contiguous().t() # make col major
w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
None if bt.w_g_s is None else bt.w_g_s.dtype)
w_q = ops.machete_prepack_B(
w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
)
w_g_zp = bt.w_g_zp
if w_g_zp is not None:
@ -275,26 +306,24 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
# bench
def bench_fns(label: str, sub_label: str, description: str,
fns: list[Callable]):
def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
min_run_time = 1 if not NVTX_PROFILE else 0.1
res = TBenchmark.Timer(
stmt="""
for fn in fns:
fn()
""",
globals={
"fns": fns
},
globals={"fns": fns},
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
if NVTX_PROFILE:
with nvtx.annotate("mm-bench"), nvtx.annotate(
f"{label}|{sub_label}|{description}"):
with (
nvtx.annotate("mm-bench"),
nvtx.annotate(f"{label}|{sub_label}|{description}"),
):
fns[0]()
return res
@ -304,19 +333,20 @@ _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
def bench(types: TypeConfig,
def bench(
types: TypeConfig,
group_size: int,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
sweep_schedules: bool = True) -> list[TMeasurement]:
sweep_schedules: bool = True,
) -> list[TMeasurement]:
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
sub_label += f", L={len(benchmark_tensors)}"
name_type_string = f"W{types.weight_type}"+\
f"-A{terse_type_name(types.act_type)}"
name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
if types.group_scale_type is not None:
name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
if types.group_zero_type is not None:
@ -332,31 +362,45 @@ def bench(types: TypeConfig,
# pytorch impl
timers.append(
bench_fns(
label, sub_label, "torch.matmul (fp16)",
[torch_matmul_f16_create_bench_fn(bt)
for bt in benchmark_tensors]))
label,
sub_label,
"torch.matmul (fp16)",
[torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
)
)
if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
timers.append(
bench_fns(
label, sub_label,
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
cutlass_scaled_mm_create_bench_fn(bt)
for bt in benchmark_tensors
]))
label,
sub_label,
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
[cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
)
)
if types.act_type != torch.float8_e4m3fn:
timers.append(
bench_fns(label, sub_label, f"marlin ({name_type_string})",
[marlin_create_bench_fn(bt)
for bt in benchmark_tensors]))
bench_fns(
label,
sub_label,
f"marlin ({name_type_string})",
[marlin_create_bench_fn(bt) for bt in benchmark_tensors],
)
)
# machete
timers.append(
bench_fns(label, sub_label, f"machete ({name_type_string})", [
bench_fns(
label,
sub_label,
f"machete ({name_type_string})",
[
machete_create_bench_fn(bt, out_type=types.output_type)
for bt in benchmark_tensors
]))
],
)
)
if sweep_schedules:
global _SWEEP_SCHEDULES_RESULTS
@ -371,7 +415,8 @@ def bench(types: TypeConfig,
group_zeros_type=types.group_zero_type,
token_scales_type=types.token_scale_type,
channel_scales_type=types.channel_scale_type,
out_type=types.output_type)
out_type=types.output_type,
)
if schedules is None or len(schedules) == 0:
raise ValueError("No schedules found to sweep")
@ -383,11 +428,17 @@ def bench(types: TypeConfig,
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue
res = bench_fns(label, sub_label, "machete_best", [
res = bench_fns(
label,
sub_label,
"machete_best",
[
machete_create_bench_fn(
bt, out_type=types.output_type, schedule=schedule)
bt, out_type=types.output_type, schedule=schedule
)
for bt in benchmark_tensors
])
],
)
results_row = {
"M": m,
@ -398,10 +449,8 @@ def bench(types: TypeConfig,
"median": res.median,
}
if _SWEEP_SCHEDULES_RESULTS is None:
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
columns=results_row.keys())
_SWEEP_SCHEDULES_RESULTS.\
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
_SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median:
@ -422,7 +471,8 @@ def print_timers(timers: list[TMeasurement]):
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
types = TypeConfig(
act_type=args.act_type,
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
weight_type=scalar_types.uint4b8
if args.group_zero_type is None
else scalar_types.uint4,
output_type=args.out_type,
group_scale_type=args.group_scale_type,
@ -433,14 +483,16 @@ def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
results: list[TMeasurement] = []
for m, k, n in MKNs:
timers = bench(types,
timers = bench(
types,
args.group_size,
m,
k,
n,
f"{args.act_type}-gemm",
f"MKN=({m}x{k}x{n})",
sweep_schedules=args.sweep_schedules)
sweep_schedules=args.sweep_schedules,
)
print_timers(timers)
results.extend(timers)
@ -454,7 +506,6 @@ def make_output(
base_description: str,
timestamp=None,
):
print(f"==== All Results {base_description} ====")
print_timers(data)
@ -468,8 +519,7 @@ def make_output(
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, args.sweep_schedules, MKNs)
@ -479,8 +529,9 @@ def run_square_bench(args):
def run_range_bench(args):
m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
m_increment, k_increment, n_increment = \
(int(x) for x in args.dim_increment.split(","))
m_increment, k_increment, n_increment = (
int(x) for x in args.dim_increment.split(",")
)
Ms = list(range(m_start, m_end + 1, m_increment))
Ks = list(range(k_start, k_end + 1, k_increment))
Ns = list(range(n_start, n_end + 1, n_increment))
@ -492,7 +543,6 @@ def run_range_bench(args):
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
@ -535,10 +585,13 @@ def run_model_bench(args):
with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
args_dict = vars(args)
args_dict.pop("func")
pkl.dump({
pkl.dump(
{
"args": args_dict,
"results": all_results,
}, f)
},
f,
)
if __name__ == "__main__":
@ -554,7 +607,6 @@ if __name__ == "__main__":
}[dt]
class ToTorchDtype(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, to_torch_dtype(values))
@ -580,32 +632,32 @@ Benchmark Machete GEMM.
"--act-type",
action=ToTorchDtype,
required=True,
choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
)
parser.add_argument(
"--group-scale-type",
action=ToTorchDtype,
choices=['bfloat16', 'float16'],
choices=["bfloat16", "float16"],
)
parser.add_argument(
"--group-zero-type",
type=to_torch_dtype,
choices=['bfloat16', 'float16'],
choices=["bfloat16", "float16"],
)
parser.add_argument(
"--channel-scale-type",
action=ToTorchDtype,
choices=['float'],
choices=["float"],
)
parser.add_argument(
"--token-scale-type",
action=ToTorchDtype,
choices=['float'],
choices=["float"],
)
parser.add_argument(
"--out-type",
action=ToTorchDtype,
choices=['bfloat16', 'float16'],
choices=["bfloat16", "float16"],
)
parser.add_argument(
"--group-size",
@ -618,9 +670,11 @@ Benchmark Machete GEMM.
action="store_true",
help="Run a sweep over all supported schedules",
)
parser.add_argument("--sweep-csv-out",
parser.add_argument(
"--sweep-csv-out",
help="CSV to store sweep results",
default="sch_sweep_results.csv")
default="sch_sweep_results.csv",
)
subparsers = parser.add_subparsers(dest="cmd", required=True)
square_parser = subparsers.add_parser("square_bench")
@ -634,17 +688,20 @@ Benchmark Machete GEMM.
"--dim-start",
type=str,
required=True,
help="Start value for M,K,N as a comma-separated list")
help="Start value for M,K,N as a comma-separated list",
)
range_parser.add_argument(
"--dim-end",
type=str,
required=True,
help="End value (inclusive) for M,K,N as a comma-separated list")
help="End value (inclusive) for M,K,N as a comma-separated list",
)
range_parser.add_argument(
"--dim-increment",
type=str,
required=True,
help="Increment value for M,K,N as a comma-separated list")
help="Increment value for M,K,N as a comma-separated list",
)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
@ -655,14 +712,12 @@ Benchmark Machete GEMM.
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.add_argument(
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
)
model_parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
View File
@ -6,19 +6,34 @@ from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
GPTQ_MARLIN_24_MAX_PARALLEL,
GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
ALLSPARK_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
GPTQ_MARLIN_MAX_PARALLEL,
GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES,
query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, marlin_quantize)
MarlinWorkspace,
marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
gptq_pack,
gptq_quantize_weights,
quantize_weights,
sort_weights,
)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser
@ -29,22 +44,29 @@ ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
def bench_run(results: list[benchmark.Measurement], model: str,
act_order: bool, is_k_full: bool, quant_type: ScalarType,
group_size: int, size_m: int, size_k: int, size_n: int):
def bench_run(
results: list[benchmark.Measurement],
model: str,
act_order: bool,
is_k_full: bool,
quant_type: ScalarType,
group_size: int,
size_m: int,
size_k: int,
size_n: int,
):
label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, q={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full,
str(quant_type), group_size, size_m,
size_k, size_n))
sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
)
print(f"Testing: {sub_label}")
a = torch.randn(size_m, size_k).to(torch.half).cuda()
b = torch.rand(size_k, size_n).to(torch.half).cuda()
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
# Marlin quant
(
@ -57,14 +79,16 @@ def bench_run(results: list[benchmark.Measurement], model: str,
) = marlin_quantize(b, quant_type, group_size, act_order)
# Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
marlin_24_quantize(b, quant_type, group_size)
)
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
b, quant_type, group_size, act_order
)
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx"
@ -74,32 +98,37 @@ def bench_run(results: list[benchmark.Measurement], model: str,
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
# Prepare
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
marlin_workspace = MarlinWorkspace(
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
)
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
marlin_24_workspace = MarlinWorkspace(
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
# AllSpark W8A16 quant
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
and group_size == -1 and not act_order and is_k_full)
as_supported_case = (
quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
and group_size == -1
and not act_order
and is_k_full
)
if as_supported_case:
properties = torch.cuda.get_device_properties(b.device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor
supported_arch = (sm_version >= 80 and sm_version < 90)
supported_arch = sm_version >= 80 and sm_version < 90
as_supported_case = as_supported_case and supported_arch
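# Illustration: Ampere-class GPUs qualify (e.g. A100 is sm_80, RTX 3090 is
# sm_86) while Hopper (H100, sm_90) falls outside the supported range checked
# above.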
if supported_arch:
has_zp = False
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
has_zp)
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
qw = qw.to(torch.uint8)
qw_reorder, s_reorder, zp_reorder = \
ops.allspark_repack_weight(
qw, s, zp, has_zp)
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
qw, s, zp, has_zp
)
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
globals = {
@ -136,8 +165,7 @@ def bench_run(results: list[benchmark.Measurement], model: str,
"zp_reorder": zp_reorder if as_supported_case else None,
"sm_count": sm_count if as_supported_case else None,
"sm_version": sm_version if as_supported_case else None,
"CUBLAS_M_THRESHOLD":
CUBLAS_M_THRESHOLD if as_supported_case else None,
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
@ -158,60 +186,63 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label,
sub_label=sub_label,
description="pytorch_gemm",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp16",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
if (
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
):
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_24_gemm",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
results.append(
benchmark.Timer(
stmt=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
if as_supported_case:
results.append(
benchmark.Timer(
stmt=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="allspark_w8a16_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))
).blocked_autorange(min_run_time=min_run_time)
)
def main(args):
@ -233,37 +264,50 @@ def main(args):
continue
for act_order in ACT_ORDER_OPTS:
if len(args.limit_act_order
) > 0 and act_order not in args.limit_act_order:
if (
len(args.limit_act_order) > 0
and act_order not in args.limit_act_order
):
continue
for is_k_full in K_FULL_OPTS:
if len(args.limit_k_full
) > 0 and is_k_full not in args.limit_k_full:
if (
len(args.limit_k_full) > 0
and is_k_full not in args.limit_k_full
):
continue
for quant_type in query_marlin_supported_quant_types(
False):
if len(args.limit_num_bits) > 0 and \
quant_type.size_bits not in args.limit_num_bits:
for quant_type in query_marlin_supported_quant_types(False):
if (
len(args.limit_num_bits) > 0
and quant_type.size_bits not in args.limit_num_bits
):
continue
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
if len(
args.limit_group_size
) > 0 and group_size not in args.limit_group_size:
if (
len(args.limit_group_size) > 0
and group_size not in args.limit_group_size
):
continue
# For act_order, the group_size must be less than
# size_k
if act_order and (group_size == size_k
or group_size == -1):
if act_order and (group_size == size_k or group_size == -1):
continue
for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full,
quant_type, group_size, size_m,
size_k, size_n)
bench_run(
results,
model,
act_order,
is_k_full,
quant_type,
group_size,
size_m,
size_k,
size_n,
)
compare = benchmark.Compare(results)
compare.print()
@ -274,7 +318,8 @@ def main(args):
#
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
description="Benchmark Marlin across specified models/shapes/batches"
)
parser.add_argument(
"--models",
nargs="+",
@ -282,10 +327,9 @@ if __name__ == "__main__":
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
View File
@ -31,7 +31,8 @@ class BenchmarkConfig(TypedDict):
num_stages: int
def benchmark_config(config: BenchmarkConfig,
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
@ -42,45 +43,48 @@ def benchmark_config(config: BenchmarkConfig,
use_int8_w8a16: bool,
num_iters: int = 100,
block_quant_shape: List[int] = None,
use_deep_gemm: bool = False) -> float:
use_deep_gemm: bool = False,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16:
w1 = torch.randint(-127,
127, (
w1 = torch.randint(
-127,
127,
(
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
dtype=torch.int8,
)
w2 = torch.randint(
-127,
127,
(
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8)
dtype=torch.int8,
)
else:
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
dtype=torch.float32)
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
dtype=torch.float32)
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
if block_quant_shape:
@ -93,10 +97,14 @@ def benchmark_config(config: BenchmarkConfig,
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
dtype=torch.float32) * factor_for_scale
w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
dtype=torch.float32) * factor_for_scale
w1_scale = (
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
* factor_for_scale
)
w2_scale = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
* factor_for_scale
)
else:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
@ -114,10 +122,12 @@ def benchmark_config(config: BenchmarkConfig,
def run():
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
if use_deep_gemm:
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, False)
x, input_gating, topk, False
)
return fused_experts(
x,
w1,
@ -213,8 +223,7 @@ def get_rocm_tuning_space(use_fp16):
return param_ranges
def get_configs_compute_bound(use_fp16,
block_quant_shape) -> list[dict[str, int]]:
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
configs: list[BenchmarkConfig] = []
if current_platform.is_rocm():
@ -250,20 +259,25 @@ def get_configs_compute_bound(use_fp16,
if block_quant_shape is not None and not use_fp16:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
for config in configs[:]:
if config["BLOCK_SIZE_K"] % block_k != 0 or config[
"BLOCK_SIZE_N"] % block_n != 0:
if (
config["BLOCK_SIZE_K"] % block_k != 0
or config["BLOCK_SIZE_N"] % block_n != 0
):
configs.remove(config)
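# For example, with block_quant_shape = [128, 128] a candidate config with
# BLOCK_SIZE_K = 64 is removed because 64 % 128 != 0, while one with
# BLOCK_SIZE_K = 256 and BLOCK_SIZE_N = 128 is kept; values are illustrative.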
return configs
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
search_space, is_fp16, topk):
def prune_rocm_search_space(
num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2
pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
search_space, is_fp16)
pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
search_space, is_fp16)
pruned_space_1 = prune_rocm_configs(
num_tokens * topk, N1, K1, search_space, is_fp16
)
pruned_space_2 = prune_rocm_configs(
num_tokens * topk, N2, K2, search_space, is_fp16
)
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space
@ -301,14 +315,14 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M")
if is_fp16:
if (matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N):
if (
matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N
):
continue
if (matrix_instr_nonkdim >= M
and matrix_instr_nonkdim != BLOCK_SIZE_M):
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
continue
if (matrix_instr_nonkdim >= N
and matrix_instr_nonkdim != BLOCK_SIZE_N):
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
continue
# Skip BLOCK_SIZE that is too large compared to M/N
# unless BLOCK_SIZE is already small enough
@ -329,8 +343,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
LDS = (
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
)
if LDS > 65536:
continue
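# Worked example with illustrative tile sizes: BLOCK_SIZE_M = BLOCK_SIZE_N = 128,
# BLOCK_SIZE_K = 64 and fp16 operands (2 bytes each) give
# LDS = 64 * 128 * 2 + 64 * 128 * 2 = 32768 bytes, within the 65536-byte budget.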
# Skip small block sizes and num_warps for large gemm
@ -364,7 +380,6 @@ def merge_unique_dicts(list1, list2):
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
@ -388,25 +403,28 @@ class BenchmarkWorker:
use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
dtype_str)
op_config = get_moe_configs(
num_experts, shard_intermediate_size // 2, dtype_str
)
if op_config is None:
config = get_default_config(num_tokens,
config = get_default_config(
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype_str,
is_marlin=False)
is_marlin=False,
)
else:
config = op_config[min(op_config.keys(),
key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config,
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
@ -417,7 +435,8 @@ class BenchmarkWorker:
use_int8_w8a16,
num_iters=100,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm)
use_deep_gemm=use_deep_gemm,
)
return config, kernel_time
def tune(
@ -438,10 +457,14 @@ class BenchmarkWorker:
best_time = float("inf")
if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = prune_rocm_search_space(num_tokens,
search_space = prune_rocm_search_space(
num_tokens,
shard_intermediate_size,
hidden_size, search_space,
is_fp16, topk)
hidden_size,
search_space,
is_fp16,
topk,
)
need_device_guard = False
if current_platform.is_rocm():
@ -449,8 +472,7 @@ class BenchmarkWorker:
if visible_device != f"{self.device_id}":
need_device_guard = True
with torch.cuda.device(
self.device_id) if need_device_guard else nullcontext():
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
@ -465,7 +487,8 @@ class BenchmarkWorker:
use_int8_w8a16,
num_iters=20,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm)
use_deep_gemm=use_deep_gemm,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
@ -481,42 +504,44 @@ class BenchmarkWorker:
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M":
config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N":
config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K":
config["BLOCK_SIZE_K"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
**({
"waves_per_eu": config["waves_per_eu"]
} if "waves_per_eu" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
**(
{"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
if "matrix_instr_nonkdim" in config
else {}
),
**({"kpack": config["kpack"]} if "kpack" in config else {}),
}
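# The **({...} if ... else {}) pattern above splices the optional ROCm keys in
# only when present; e.g. {"BLOCK_SIZE_M": 16, **({"kpack": 2} if True else {})}
# evaluates to {"BLOCK_SIZE_M": 16, "kpack": 2} (values shown are illustrative).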
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
block_quant_shape: List[int]) -> None:
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
def save_configs(
configs: dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_quant_shape: List[int],
) -> None:
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str, block_quant_shape)
filename = get_config_file_name(
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
@ -525,18 +550,16 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
def get_weight_block_size_safety(config, default_value=None):
quantization_config = getattr(config, 'quantization_config', {})
quantization_config = getattr(config, "quantization_config", {})
if isinstance(quantization_config, dict):
return quantization_config.get('weight_block_size', default_value)
return quantization_config.get("weight_block_size", default_value)
return default_value
def main(args: argparse.Namespace):
print(args)
config = get_config(model=args.model,
trust_remote_code=args.trust_remote_code)
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
if args.model_prefix:
config = getattr(config, args.model_prefix)
config = SimpleNamespace(**config)
@ -551,14 +574,12 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif (config.architectures[0]
in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")):
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ("Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM"):
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
@ -573,16 +594,35 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else getattr(
torch, config.torch_dtype)
dtype = (
torch.float16
if current_platform.is_rocm()
else getattr(torch, config.torch_dtype)
)
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config)
if args.batch_size is None:
batch_sizes = [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
@ -593,7 +633,8 @@ def main(args: argparse.Namespace):
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
logger.warning(
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.")
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
)
val = os.environ["HIP_VISIBLE_DEVICES"]
os.environ["ROCR_VISIBLE_DEVICES"] = val
del os.environ["HIP_VISIBLE_DEVICES"]
@ -620,25 +661,59 @@ def main(args: argparse.Namespace):
start = time.time()
configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
block_quant_shape, use_deep_gemm)
for batch_size in batch_sizes])
"tune",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
search_space,
block_quant_shape,
use_deep_gemm,
)
for batch_size in batch_sizes
],
)
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
M: sort_config(config) for M, config in zip(batch_sizes, configs)
}
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16,
block_quant_shape)
save_configs(
best_configs,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
block_quant_shape,
)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute(
"benchmark",
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
for batch_size in batch_sizes])
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
block_quant_shape,
use_deep_gemm,
)
for batch_size in batch_sizes
],
)
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
@ -647,18 +722,15 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
parser.add_argument("--tp-size",
"-tp",
"--tensor-parallel-size",
type=int,
default=2)
parser.add_argument("--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument(
"--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
)
parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
)
parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
View File
@ -8,7 +8,9 @@ import torch
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_moe_permute, _moe_unpermute_and_reduce)
_moe_permute,
_moe_unpermute_and_reduce,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
@ -27,7 +29,8 @@ class BenchmarkConfig(TypedDict):
num_stages: int
def benchmark_permute(num_tokens: int,
def benchmark_permute(
num_tokens: int,
num_experts: int,
hidden_size: int,
topk: int,
@ -35,7 +38,8 @@ def benchmark_permute(num_tokens: int,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
use_customized_permute: bool = False) -> float:
use_customized_permute: bool = False,
) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
# output_hidden_states = torch.empty_like(hidden_states)
@ -46,22 +50,20 @@ def benchmark_permute(num_tokens: int,
align_block_size = None
qhidden_states = hidden_states
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
dtype=torch.float32)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_weights, topk_ids, token_expert_indices = fused_topk(
qhidden_states, input_gating, topk, False)
qhidden_states, input_gating, topk, False
)
def prepare(i: int):
input_gating.copy_(gating_output[i])
def run():
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx,
m_indices) = moe_permute(
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
moe_permute(
qhidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
@ -72,10 +74,17 @@ def benchmark_permute(num_tokens: int,
expert_map=None,
align_block_size=align_block_size,
)
)
else:
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
num_experts, None, align_block_size)
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(
qhidden_states, None, topk_ids, num_experts, None, align_block_size
)
# JIT compilation & warmup
run()
@ -111,7 +120,8 @@ def benchmark_permute(num_tokens: int,
return avg
def benchmark_unpermute(num_tokens: int,
def benchmark_unpermute(
num_tokens: int,
num_experts: int,
hidden_size: int,
topk: int,
@ -119,7 +129,8 @@ def benchmark_unpermute(num_tokens: int,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
use_customized_permute: bool = False) -> float:
use_customized_permute: bool = False,
) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
output_hidden_states = torch.empty_like(hidden_states)
@ -133,12 +144,13 @@ def benchmark_unpermute(num_tokens: int,
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_weights, topk_ids, token_expert_indices = fused_topk(
qhidden_states, input_gating, topk, False)
qhidden_states, input_gating, topk, False
)
def prepare():
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx,
m_indices) = moe_permute(
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
moe_permute(
qhidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
@ -149,30 +161,57 @@ def benchmark_unpermute(num_tokens: int,
expert_map=None,
align_block_size=align_block_size,
)
)
# convert to fp16/bf16 as gemm output
return (permuted_hidden_states.to(dtype), first_token_off,
inv_perm_idx, m_indices)
return (
permuted_hidden_states.to(dtype),
first_token_off,
inv_perm_idx,
m_indices,
)
else:
(permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
num_experts, None, align_block_size)
(
permuted_qhidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(
qhidden_states, None, topk_ids, num_experts, None, align_block_size
)
# convert to fp16/bf16 as gemm output
return (permuted_qhidden_states.to(dtype), a1q_scale,
sorted_token_ids, expert_ids, inv_perm)
return (
permuted_qhidden_states.to(dtype),
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
)
def run(input: tuple):
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx,
m_indices) = input
moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
inv_perm_idx, first_token_off, topk, num_experts,
num_experts)
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
moe_unpermute(
permuted_hidden_states,
topk_weights,
topk_ids,
inv_perm_idx,
first_token_off,
topk,
num_experts,
num_experts,
)
else:
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = input
_moe_unpermute_and_reduce(output_hidden_states,
permuted_hidden_states, inv_perm,
topk_weights)
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = input
_moe_unpermute_and_reduce(
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
)
# JIT compilation & warmup
input = prepare()
@ -209,7 +248,6 @@ def benchmark_unpermute(num_tokens: int,
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
@ -241,7 +279,8 @@ class BenchmarkWorker:
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
use_customized_permute=use_customized_permute)
use_customized_permute=use_customized_permute,
)
unpermute_time = benchmark_unpermute(
num_tokens,
num_experts,
@ -251,15 +290,15 @@ class BenchmarkWorker:
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
use_customized_permute=use_customized_permute)
use_customized_permute=use_customized_permute,
)
return permute_time, unpermute_time
def get_weight_block_size_safety(config, default_value=None):
quantization_config = getattr(config, 'quantization_config', {})
quantization_config = getattr(config, "quantization_config", {})
if isinstance(quantization_config, dict):
return quantization_config.get('weight_block_size', default_value)
return quantization_config.get("weight_block_size", default_value)
return default_value
@ -267,20 +306,21 @@ def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code)
args.model, trust_remote_code=args.trust_remote_code
)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
or config.architectures[0] == "DeepseekV2ForCausalLM"):
elif (
config.architectures[0] == "DeepseekV3ForCausalLM"
or config.architectures[0] == "DeepseekV2ForCausalLM"
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
elif config.architectures[0] in [
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
]:
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
@ -299,8 +339,24 @@ def main(args: argparse.Namespace):
if args.batch_size is None:
batch_sizes = [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
@ -321,9 +377,21 @@ def main(args: argparse.Namespace):
return ray.get(outputs)
outputs = _distribute(
"benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
use_int8_w8a16, use_customized_permute)
for batch_size in batch_sizes])
"benchmark",
[
(
batch_size,
E,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
use_customized_permute,
)
for batch_size in batch_sizes
],
)
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}")
@ -333,13 +401,12 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
parser.add_argument("--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
)
parser.add_argument("--use-customized-permute", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)


@ -9,8 +9,11 @@ import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)
from vllm.utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)
logger = init_logger(__name__)
@ -38,19 +41,15 @@ def main(
current_platform.seed_everything(seed)
scale = float(1.0 / (head_size**0.5))
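# e.g. with the script's default head_size=128 this is the usual 1/sqrt(d)
# attention scale, 1/sqrt(128) ≈ 0.0884 (illustrative arithmetic only)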
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device=device)
query = torch.empty(
num_seqs, num_query_heads, head_size, dtype=dtype, device=device
)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device=device)
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
seq_lens = [seq_len for _ in range(num_seqs)]
max_seq_len = max(seq_lens)
@ -61,24 +60,23 @@ def main(
block_tables_lst: list[list[int]] = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
]
block_tables_lst.append(block_table)
block_tables = torch.tensor(block_tables_lst,
dtype=torch.int,
device=device)
block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
# Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
key_caches, value_caches = create_kv_caches_with_random(
NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
device=device)
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# Prepare for the paged attention kernel.
@ -86,11 +84,8 @@ def main(
if version == "v2":
if current_platform.is_rocm():
global PARTITION_SIZE
if not args.custom_paged_attn:
PARTITION_SIZE = 1024
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
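# ceiling division; e.g. with the default seq-len of 4096 and PARTITION_SIZE=1024
# (illustrative values) this gives (4096 + 1023) // 1024 = 4 partitions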
tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype,
@ -110,9 +105,7 @@ def main(
start_time = time.perf_counter()
# Using default kv_scale
k_scale = v_scale = torch.tensor(1.0,
dtype=torch.float32,
device=device)
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
for _ in range(num_iters):
if version == "v1":
@ -195,30 +188,29 @@ def main(
print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__':
logger.warning("This script benchmarks the paged attention kernel. "
"By default this is no longer used in vLLM inference.")
if __name__ == "__main__":
logger.warning(
"This script benchmarks the paged attention kernel. "
"By default this is no longer used in vLLM inference."
)
parser = FlexibleArgumentParser(
description="Benchmark the paged attention kernel.")
parser.add_argument("--version",
type=str,
choices=["v1", "v2"],
default="v2")
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
parser.add_argument(
"--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128)
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true")
parser.add_argument("--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument(
@ -228,10 +220,11 @@ if __name__ == '__main__':
default="auto",
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
parser.add_argument("--custom-paged-attn",
action="store_true",
help="Use custom paged attention")
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
)
parser.add_argument(
"--custom-paged-attn", action="store_true", help="Use custom paged attention"
)
args = parser.parse_args()
print(args)


@ -10,7 +10,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode()
def main(num_tokens: int,
def main(
num_tokens: int,
hidden_size: int,
static_scale: bool,
quant_dtype: torch.dtype,
@ -18,7 +19,8 @@ def main(num_tokens: int,
seed: int = 0,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
num_iters: int = 100,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device("cuda")
@ -56,7 +58,7 @@ def main(num_tokens: int,
print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__':
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
@ -66,32 +68,34 @@ if __name__ == '__main__':
raise ValueError(f"Unsupported dtype: {dt}")
parser = FlexibleArgumentParser(
description="Benchmark the quantization (fp8 or int8) kernel.")
description="Benchmark the quantization (fp8 or int8) kernel."
)
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--static-scale", action="store_true")
parser.add_argument("--quant-dtype",
type=str,
choices=["fp8", "int8"],
default="int8")
parser.add_argument("--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument(
"--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
)
parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters",
parser.add_argument(
"--num-iters",
type=int,
default=100,
help="Number of benchmark iterations. "
"If --profile is set, this number is ignored")
"If --profile is set, this number is ignored",
)
args = parser.parse_args()
print(args)
main(num_tokens=args.num_tokens,
main(
num_tokens=args.num_tokens,
hidden_size=args.hidden_size,
static_scale=args.static_scale,
quant_dtype=to_torch_dtype(args.quant_dtype),
@ -99,4 +103,5 @@ if __name__ == '__main__':
seed=args.seed,
do_profile=args.profile,
num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters)
num_iters=args.num_iters,
)


@ -12,7 +12,6 @@ from vllm.triton_utils import triton
class HuggingFaceRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
@ -114,23 +113,19 @@ def rmsnorm_vllm(
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16
x = torch.randn(batch_size,
seq_len,
hidden_size,
dtype=dtype,
device="cuda")
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
output_naive = rmsnorm_naive(
x.clone(), weight,
residual.clone() if residual is not None else None)
x.clone(), weight, residual.clone() if residual is not None else None
)
output_flashinfer = rmsnorm_flashinfer(
x.clone(), weight,
residual.clone() if residual is not None else None)
x.clone(), weight, residual.clone() if residual is not None else None
)
output_vllm = rmsnorm_vllm(
x.clone(), weight,
residual.clone() if residual is not None else None)
x.clone(), weight, residual.clone() if residual is not None else None
)
if use_residual:
output_naive = output_naive[0]
@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
print(f"FlashInfer output={output_flashinfer}")
print(f"vLLM output={output_vllm}")
if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
rtol=1e-2) and torch.allclose(
output_naive, output_vllm, atol=1e-2, rtol=1e-2):
if torch.allclose(
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(
itertools.product(head_num_range, batch_size_range, seq_length_range))
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
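# the product above yields 2 * 4 * 5 = 40 (head_num, batch_size, seq_len) tuples,
# e.g. (32, 1, 64), (32, 1, 128), ..., (48, 64, 1024)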
def get_benchmark(use_residual):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"],
@ -167,19 +160,15 @@ def get_benchmark(use_residual):
line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name=
f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
args={},
))
)
)
def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128
x = torch.randn(batch_size,
seq_len,
hidden_size,
dtype=dtype,
device="cuda")
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
@ -240,9 +229,9 @@ if __name__ == "__main__":
default=4096,
help="Hidden size (2nd dimension) of the sequence",
)
parser.add_argument("--use-residual",
action="store_true",
help="Whether to use residual connection")
parser.add_argument(
"--use-residual", action="store_true", help="Whether to use residual connection"
)
parser.add_argument(
"--save-path",
type=str,
@ -253,10 +242,12 @@ if __name__ == "__main__":
args = parser.parse_args()
# Run correctness test
calculate_diff(batch_size=args.batch_size,
calculate_diff(
batch_size=args.batch_size,
seq_len=args.seq_len,
hidden_size=args.hidden_size,
use_residual=args.use_residual)
use_residual=args.use_residual,
)
# Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual)


@ -6,8 +6,7 @@ from typing import Optional
import nvtx
import torch
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora(
# simulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, {
"rope_type": "linear",
"factor": tuple(scaling_factors)
})
batched_rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
{"rope_type": "linear", "factor": tuple(scaling_factors)},
)
# non-batched RoPE takes only one scaling factor, so we create multiple
# instances to simulate the same behavior
non_batched_ropes: list[RotaryEmbedding] = []
for scaling_factor in scaling_factors:
non_batched_ropes.append(
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
{
"rope_type": "linear",
"factor": (scaling_factor, )
}))
get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
{"rope_type": "linear", "factor": (scaling_factor,)},
)
)
positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
key = torch.randn_like(query)
# create query offsets for batched RoPE; we concat multiple kv caches
# together and each query needs to find the right kv cache of its type
offset_map = torch.tensor(
list(
accumulate([0] + [
accumulate(
[0]
+ [
max_position * scaling_factor * 2
for scaling_factor in scaling_factors[:-1]
])))
query_types = torch.randint(0,
len(scaling_factors), (batch_size, seq_len),
device=device)
]
)
)
)
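# illustrative example (max_position is a parameter, 8192 assumed here): with
# scaling_factors=[1, 2, 4, 8], accumulate([0, 16384, 32768, 65536]) gives
# offsets [0, 16384, 49152, 114688], one starting offset per kv cache region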
query_types = torch.randint(
0, len(scaling_factors), (batch_size, seq_len), device=device
)
# map query types to offsets
query_offsets = offset_map[query_types]
# the kernel takes flattened offsets
@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora(
torch.cuda.synchronize()
if __name__ == '__main__':
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.")
description="Benchmark the rotary embedding kernels."
)
parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--seq-len", type=int, default=512)
parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size",
parser.add_argument(
"--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128)
default=128,
)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype",
type=str,
choices=["bfloat16", "float"],
default="float")
parser.add_argument(
"--dtype", type=str, choices=["bfloat16", "float"], default="float"
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device",
type=str,
choices=["cuda:0", "cuda:1"],
default="cuda:0")
parser.add_argument(
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
)
args = parser.parse_args()
print(args)


@ -14,14 +14,16 @@ import tqdm
import triton
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul)
_w8a8_block_fp8_matmul,
)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True)
assert current_platform.is_cuda(
), "Only support tune w8a8 block fp8 kernel on CUDA device."
assert current_platform.is_cuda(), (
"Only support tune w8a8 block fp8 kernel on CUDA device."
)
DTYPE_MAP = {
"float32": torch.float32,
@ -71,18 +73,18 @@ def w8a8_block_matmul(
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_block_fp8_matmul
else:
raise RuntimeError(
"Currently, only support tune w8a8 block fp8 kernel.")
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
kernel[grid](
A,
@ -119,14 +121,16 @@ def get_configs_compute_bound():
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append({
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
})
}
)
return configs
@ -165,15 +169,9 @@ def get_weight_shapes(tp_size):
return weight_shapes
def benchmark_config(A,
B,
As,
Bs,
block_size,
config,
out_dtype=torch.float16,
num_iters=10):
def benchmark_config(
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
):
def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
fp8_max)
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
fp8_max)
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
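# assuming fp8_info = torch.finfo(torch.float8_e4m3fn), fp8_max is 448.0, so A and
# B are sampled roughly uniformly in (-448, 448) before the clamp and fp8 cast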
else:
raise RuntimeError(
"Currently, only support tune w8a8 block fp8 kernel.")
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32,
device="cuda") * factor_for_scale
Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") *
factor_for_scale)
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
Bs = (
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
* factor_for_scale
)
best_config = None
best_time = float("inf")
@ -267,7 +265,8 @@ def save_configs(
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = (
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
f"block_shape=[{block_n},{block_k}].json")
f"block_shape=[{block_n},{block_k}].json"
)
config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...")
@ -295,8 +294,7 @@ def tune_on_gpu(args_dict):
search_space = get_configs_compute_bound()
search_space = [
config for config in search_space
if block_k % config["BLOCK_SIZE_K"] == 0
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
]
start = time.time()
@ -312,15 +310,11 @@ def tune_on_gpu(args_dict):
out_dtype,
search_space,
input_type,
) for batch_size in tqdm(batch_sizes,
desc=f"GPU {gpu_id} - Batch sizes")
)
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
]
best_configs = {
M: config
for M, config in zip(batch_sizes, benchmark_results)
}
save_configs(N, K, block_n, block_k, best_configs, save_path,
input_type)
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
end = time.time()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
@ -376,13 +370,14 @@ def main(args):
process_args = []
for gpu_id in range(num_gpus):
process_args.append({
process_args.append(
{
"gpu_id": gpu_id,
"batch_sizes": batches_per_gpu[gpu_id],
"weight_shapes":
weight_shapes, # Each GPU processes all weight shapes
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
"args": args,
})
}
)
ctx = mp.get_context("spawn")
with ctx.Pool(num_gpus) as pool:
@ -398,13 +393,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs
""",
formatter_class=argparse.RawTextHelpFormatter)
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument("--input-type",
type=str,
choices=["fp8"],
default="fp8")
parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
parser.add_argument(
"--out-dtype",
type=str,


@ -11,7 +11,9 @@ from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton


@ -14,13 +14,14 @@ from vllm.utils import FlexibleArgumentParser
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('filename', type=str)
description="Benchmark the latency of processing a single batch of "
"requests till completion."
)
parser.add_argument("filename", type=str)
args = parser.parse_args()
with open(args.filename, 'rb') as f:
with open(args.filename, "rb") as f:
data = pickle.load(f)
raw_results: list[TMeasurement] = data["results"]
@ -38,11 +39,7 @@ if __name__ == "__main__":
raise Exception("MKN not found")
kernel = v.task_spec.description
results[KN].append({
"kernel": kernel,
"batch_size": M,
"median": v.median
})
results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})
rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
@ -50,14 +47,16 @@ if __name__ == "__main__":
for axs_idx, (shape, data) in enumerate(results.items()):
plt.sca(axs[axs_idx])
df = pd.DataFrame(data)
sns.lineplot(data=df,
sns.lineplot(
data=df,
x="batch_size",
y="median",
hue="kernel",
style="kernel",
markers=True,
dashes=False,
palette="Dark2")
palette="Dark2",
)
plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)")
plt.tight_layout()


@ -23,6 +23,7 @@ class ArgPool:
For every invocation during a benchmarking run, it will choose a
different value from the list.
"""
values: Iterable[Any]
def __getitem__(self, index):
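# behavior sketch (not in the original file): ArgPool([a, b, c])[1] returns b;
# Bench indexes each pool once per collapsed argument set so successive timed
# calls receive different inputs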
@ -30,9 +31,7 @@ class ArgPool:
class Bench:
class ArgsIterator:
def __init__(self, args_list, kwargs_list):
assert len(args_list) == len(kwargs_list)
self.args_list = args_list
@ -53,10 +52,16 @@ class Bench:
def n_args(self):
return self.n
def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
label: str, sub_label: str, description: str, fn: Callable,
*args, **kwargs):
def __init__(
self,
cuda_graph_params: Optional[CudaGraphBenchParams],
label: str,
sub_label: str,
description: str,
fn: Callable,
*args,
**kwargs,
):
self.cuda_graph_params = cuda_graph_params
self.use_cuda_graph = self.cuda_graph_params is not None
self.label = label
@ -67,10 +72,8 @@ class Bench:
# Process args
self._args = args
self._kwargs = kwargs
self.args_list, self.kwargs_list = self.collapse_argpool(
*args, **kwargs)
self.args_iterator = self.ArgsIterator(self.args_list,
self.kwargs_list)
self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs)
self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list)
# Cudagraph runner
self.g = None
@ -100,16 +103,13 @@ class Bench:
for i in range(argpool_size):
# collapse args; Just pick the ith value
args_list[i] = tuple([
arg[i] if isinstance(arg, ArgPool) else arg
for arg in args_list[i]
])
args_list[i] = tuple(
[arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]]
)
# collapse kwargs
kwargs_i = kwargs_list[i]
arg_pool_keys = [
k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
]
arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)]
for k in arg_pool_keys:
# again just pick the ith value
kwargs_i[k] = kwargs_i[k][i]
@ -142,7 +142,7 @@ class Bench:
def run_cudagrah(self) -> TMeasurement:
assert self.use_cuda_graph
globals = {'g': self.g}
globals = {"g": self.g}
return TBenchmark.Timer(
stmt="g.replay()",
@ -162,15 +162,15 @@ class Bench:
has_arg_pool = self.args_iterator.n_args > 1
if has_arg_pool:
setup = '''
setup = """
args_iterator.reset()
args_it = args_iterator.__next__()
'''
stmt = '''
"""
stmt = """
args, kwargs = next(args_it)
fn(*args, **kwargs)
'''
globals = {'fn': self.fn, 'args_iterator': self.args_iterator}
"""
globals = {"fn": self.fn, "args_iterator": self.args_iterator}
else:
# no arg pool. Just use the args and kwargs directly
self.args_iterator.reset()
@ -178,10 +178,10 @@ class Bench:
args, kwargs = next(args_it)
setup = ""
stmt = '''
stmt = """
fn(*args, **kwargs)
'''
globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}
"""
globals = {"fn": self.fn, "args": args, "kwargs": kwargs}
return TBenchmark.Timer(
stmt=stmt,


@ -7,9 +7,8 @@ from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)
LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
LONG_PROMPT = " ".join(LONG_PROMPT)
def main(args):
@ -30,32 +29,35 @@ def main(args):
print("------start generating------")
for i in range(3):
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
globals(), locals())
profiler.runctx(
"llm.generate(LONG_PROMPT, sampling_params)", globals(), locals()
)
# analyze the runtime of hashing function
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.sort_stats("cumulative")
total_time = 0
total_calls = 0
for func in stats.stats:
if 'hash_of_block' in func[2]:
if "hash_of_block" in func[2]:
total_time = stats.stats[func][3]
total_calls = stats.stats[func][0]
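# each stats.stats value is a (cc, nc, tt, ct, callers) tuple, so index 3 is the
# cumulative time spent in hash_of_block and index 0 its primitive call count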
percentage = (total_time / stats.total_tt) * 100
print(f"Hashing took {total_time:.2f} seconds,"
f"{percentage:.2f}% of the total runtime.")
print(
f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime."
)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
description="Benchmark the performance of hashing function in"
"automatic prefix caching."
)
parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k")
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--output-len", type=int, default=10)
parser.add_argument(
"--enable-prefix-caching", action="store_true", help="enable prefix caching"
)
args = parser.parse_args()
main(args)

benchmarks/pyproject.toml (new file, 54 lines)

@ -0,0 +1,54 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length = 88
exclude = [
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py",
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.lint.isort]
known-first-party = ["vllm"]
[tool.ruff.format]
docstring-code-format = true


@ -54,6 +54,7 @@ include = ["vllm*"]
[tool.yapfignore]
ignore_patterns = [
".buildkite/**",
"benchmarks/**",
"build/**",
]
@ -155,6 +156,10 @@ ignore-words-list = "dout, te, indicies, subtile, ElementE"
skip = "tests/models/fixtures/*,tests/prompts/*,benchmarks/sonnet.txt,tests/lora/data/*,build/*,vllm/third_party/*"
[tool.isort]
skip_glob = [
".buildkite/*",
"benchmarks/*",
]
use_parentheses = true
skip_gitignore = true