Compare commits
1 Commits
v0.8.5.pos
...
low_latenc
| Author | SHA1 | Date | |
|---|---|---|---|
| 79acf80471 |
@ -527,7 +527,7 @@ def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
block_quant_shape = None
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
@ -546,9 +546,8 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in [
|
||||
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
|
||||
]:
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
@ -566,7 +565,6 @@ def main(args: argparse.Namespace):
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
|
||||
@ -1,58 +0,0 @@
|
||||
# Security Guide
|
||||
|
||||
## Inter-Node Communication
|
||||
|
||||
All communications between nodes in a multi-node vLLM deployment are **insecure by default** and must be protected by placing the nodes on an isolated network. This includes:
|
||||
|
||||
1. PyTorch Distributed communications
|
||||
2. KV cache transfer communications
|
||||
3. Tensor, Pipeline, and Data parallel communications
|
||||
|
||||
### Configuration Options for Inter-Node Communications
|
||||
|
||||
The following options control inter-node communications in vLLM:
|
||||
|
||||
1. **Environment Variables:**
|
||||
- `VLLM_HOST_IP`: Sets the IP address for vLLM processes to communicate on
|
||||
|
||||
2. **KV Cache Transfer Configuration:**
|
||||
- `--kv-ip`: The IP address for KV cache transfer communications (default: 127.0.0.1)
|
||||
- `--kv-port`: The port for KV cache transfer communications (default: 14579)
|
||||
|
||||
3. **Data Parallel Configuration:**
|
||||
- `data_parallel_master_ip`: IP of the data parallel master (default: 127.0.0.1)
|
||||
- `data_parallel_master_port`: Port of the data parallel master (default: 29500)
|
||||
|
||||
### Notes on PyTorch Distributed
|
||||
|
||||
vLLM uses PyTorch's distributed features for some inter-node communication. For
|
||||
detailed information about PyTorch Distributed security considerations, please
|
||||
refer to the [PyTorch Security
|
||||
Guide](https://github.com/pytorch/pytorch/security/policy#using-distributed-features).
|
||||
|
||||
Key points from the PyTorch security guide:
|
||||
- PyTorch Distributed features are intended for internal communication only
|
||||
- They are not built for use in untrusted environments or networks
|
||||
- No authorization protocol is included for performance reasons
|
||||
- Messages are sent unencrypted
|
||||
- Connections are accepted from anywhere without checks
|
||||
|
||||
### Security Recommendations
|
||||
|
||||
1. **Network Isolation:**
|
||||
- Deploy vLLM nodes on a dedicated, isolated network
|
||||
- Use network segmentation to prevent unauthorized access
|
||||
- Implement appropriate firewall rules
|
||||
|
||||
2. **Configuration Best Practices:**
|
||||
- Always set `VLLM_HOST_IP` to a specific IP address rather than using defaults
|
||||
- Configure firewalls to only allow necessary ports between nodes
|
||||
|
||||
3. **Access Control:**
|
||||
- Restrict physical and network access to the deployment environment
|
||||
- Implement proper authentication and authorization for management interfaces
|
||||
- Follow the principle of least privilege for all system components
|
||||
|
||||
## Reporting Security Vulnerabilities
|
||||
|
||||
If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md).
|
||||
@ -132,7 +132,6 @@ serving/integrations/index
|
||||
:caption: Deployment
|
||||
:maxdepth: 1
|
||||
|
||||
deployment/security
|
||||
deployment/docker
|
||||
deployment/k8s
|
||||
deployment/nginx
|
||||
|
||||
@ -77,10 +77,6 @@ bash run_cluster.sh \
|
||||
|
||||
Then you get a ray cluster of **containers**. Note that you need to keep the shells running these commands alive to hold the cluster. Any shell disconnect will terminate the cluster. In addition, please note that the argument `ip_of_head_node` should be the IP address of the head node, which is accessible by all the worker nodes. The IP addresses of each worker node should be specified in the `VLLM_HOST_IP` environment variable, and should be different for each worker node. Please check the network configuration of your cluster to make sure the nodes can communicate with each other through the specified IP addresses.
|
||||
|
||||
:::{warning}
|
||||
It is considered best practice to set `VLLM_HOST_IP` to an address on a private network segment for the vLLM cluster. The traffic sent here is not encrypted. The endpoints are also exchanging data in a format that could be exploited to execute arbitrary code should a malicious party gain access to the network. Please ensure that this network is not reachable by any untrusted parties.
|
||||
:::
|
||||
|
||||
:::{warning}
|
||||
Since this is a ray cluster of **containers**, all the following commands should be executed in the **containers**, otherwise you are executing the commands on the host machine, which is not connected to the ray cluster. To enter the container, you can use `docker exec -it node /bin/bash`.
|
||||
:::
|
||||
|
||||
@ -10,12 +10,12 @@ prompts = [
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10)
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model="facebook/opt-125m", disable_cascade_attn=True)
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
|
||||
@ -15,8 +15,7 @@ build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "vllm"
|
||||
authors = [{name = "vLLM Team"}]
|
||||
license = "Apache-2.0"
|
||||
license-files = ["LICENSE"]
|
||||
license = { "file"= "LICENSE" }
|
||||
readme = "README.md"
|
||||
description = "A high-throughput and memory-efficient inference and serving engine for LLMs"
|
||||
classifiers = [
|
||||
@ -24,6 +23,7 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Information Technology",
|
||||
"Intended Audience :: Science/Research",
|
||||
|
||||
@ -20,11 +20,15 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
||||
("facebook/opt-125m", {}),
|
||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
||||
"dtype": torch.float16,
|
||||
"quantization": "compressed-tensors"
|
||||
}),
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
|
||||
"dtype": torch.float16,
|
||||
"quantization": "compressed-tensors"
|
||||
}),
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {
|
||||
"quantization": "compressed-tensors"
|
||||
}),
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||
]
|
||||
|
||||
|
||||
@ -1,118 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from argparse import ArgumentError, ArgumentTypeError
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import PoolerConfig, config
|
||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||
get_type, is_not_builtin, is_type,
|
||||
nullable_kvs, optional_type)
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type", "value", "expected"), [
|
||||
(int, "42", 42),
|
||||
(int, "None", None),
|
||||
(float, "3.14", 3.14),
|
||||
(float, "None", None),
|
||||
(str, "Hello World!", "Hello World!"),
|
||||
(str, "None", None),
|
||||
(json.loads, '{"foo":1,"bar":2}', {
|
||||
"foo": 1,
|
||||
"bar": 2
|
||||
}),
|
||||
(json.loads, "foo=1,bar=2", {
|
||||
"foo": 1,
|
||||
"bar": 2
|
||||
}),
|
||||
(json.loads, "None", None),
|
||||
])
|
||||
def test_optional_type(type, value, expected):
|
||||
optional_type_func = optional_type(type)
|
||||
context = nullcontext()
|
||||
if value == "foo=1,bar=2":
|
||||
context = pytest.warns(DeprecationWarning)
|
||||
with context:
|
||||
assert optional_type_func(value) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
|
||||
(int, int, True),
|
||||
(int, float, False),
|
||||
(list[int], list, True),
|
||||
(list[int], tuple, False),
|
||||
(Literal[0, 1], Literal, True),
|
||||
])
|
||||
def test_is_type(type_hint, type, expected):
|
||||
assert is_type(type_hint, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
||||
({float, int}, int, True),
|
||||
({int, tuple[int]}, int, True),
|
||||
({int, tuple[int]}, float, False),
|
||||
({str, Literal["x", "y"]}, Literal, True),
|
||||
])
|
||||
def test_contains_type(type_hints, type, expected):
|
||||
assert contains_type(type_hints, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
||||
({int, float}, int, int),
|
||||
({int, float}, str, None),
|
||||
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
||||
])
|
||||
def test_get_type(type_hints, type, expected):
|
||||
assert get_type(type_hints, type) == expected
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DummyConfigClass:
|
||||
regular_bool: bool = True
|
||||
"""Regular bool with default True"""
|
||||
optional_bool: Optional[bool] = None
|
||||
"""Optional bool with default None"""
|
||||
optional_literal: Optional[Literal["x", "y"]] = None
|
||||
"""Optional literal with default None"""
|
||||
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
|
||||
"""Tuple with default (1, 2, 3)"""
|
||||
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
|
||||
"""Tuple with default (1, 2)"""
|
||||
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
|
||||
"""List with default [1, 2, 3]"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
||||
(int, False),
|
||||
(DummyConfigClass, True),
|
||||
])
|
||||
def test_is_not_builtin(type_hint, expected):
|
||||
assert is_not_builtin(type_hint) == expected
|
||||
|
||||
|
||||
def test_get_kwargs():
|
||||
kwargs = get_kwargs(DummyConfigClass)
|
||||
print(kwargs)
|
||||
|
||||
# bools should not have their type set
|
||||
assert kwargs["regular_bool"].get("type") is None
|
||||
assert kwargs["optional_bool"].get("type") is None
|
||||
# optional literals should have None as a choice
|
||||
assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
|
||||
# tuples should have the correct nargs
|
||||
assert kwargs["tuple_n"]["nargs"] == "+"
|
||||
assert kwargs["tuple_2"]["nargs"] == 2
|
||||
# lists should work
|
||||
assert kwargs["list_n"]["type"] is int
|
||||
assert kwargs["list_n"]["nargs"] == "+"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("arg", "expected"), [
|
||||
(None, dict()),
|
||||
("image=16", {
|
||||
|
||||
@ -1165,80 +1165,3 @@ def test_kv_connector_handles_preemption():
|
||||
# All memory should be freed since nothing is running.
|
||||
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
|
||||
== NUM_BLOCKS - 1
|
||||
|
||||
|
||||
def make_output(scheduler: Scheduler):
|
||||
return ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in scheduler.running],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(scheduler.running)
|
||||
},
|
||||
sampled_token_ids=[[1000]] * len(scheduler.running),
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
|
||||
|
||||
def assert_scheduler_empty(scheduler: Scheduler):
|
||||
"""Confirm the scheduler is "empty" - i.e. no leaks."""
|
||||
# Scheduler Metadata.
|
||||
assert len(scheduler.requests) == 0
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.finished_req_ids) == 0
|
||||
assert len(scheduler._cached_reqs_data) == 0
|
||||
|
||||
# EncoderCacheManager.
|
||||
assert len(scheduler.encoder_cache_manager.freed) == 0
|
||||
assert len(scheduler.encoder_cache_manager.cached) == 0
|
||||
|
||||
# KVCache Manager.
|
||||
assert len(scheduler.kv_cache_manager.req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
|
||||
assert len(scheduler.kv_cache_manager.num_cached_block) == 0
|
||||
num_free_blocks = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
assert num_free_blocks == (
|
||||
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
|
||||
|
||||
# NOTE(rob): just the ref count on blocks will be 0. The hash
|
||||
# value, etc will remain since we lazily evict for prefix cache.
|
||||
for block in scheduler.kv_cache_manager.block_pool.blocks:
|
||||
assert block.ref_cnt == 0
|
||||
# assert block._block_hash is None
|
||||
# assert (
|
||||
# len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
|
||||
# ) == 0)
|
||||
|
||||
|
||||
def test_memory_leak():
|
||||
"""Test that we do not have a memory leak."""
|
||||
|
||||
scheduler = create_scheduler(enable_prefix_caching=True)
|
||||
|
||||
NUM_REQUESTS = 5
|
||||
NUM_TOKENS = 10
|
||||
MAX_TOKENS = 10
|
||||
requests = create_requests(num_requests=NUM_REQUESTS,
|
||||
num_tokens=NUM_TOKENS,
|
||||
max_tokens=MAX_TOKENS)
|
||||
|
||||
# Add each request.
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Iterate until done.
|
||||
while True:
|
||||
scheduler_output = scheduler.schedule()
|
||||
if len(scheduler.running) == 0:
|
||||
break
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Confirm no memory leak.
|
||||
assert_scheduler_empty(scheduler)
|
||||
|
||||
@ -28,7 +28,6 @@ import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||
QuantizationMethods,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
@ -753,8 +752,9 @@ class ModelConfig:
|
||||
supported_quantization = QUANTIZATION_METHODS
|
||||
optimized_quantization_methods = [
|
||||
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
|
||||
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
|
||||
"quark", "nvfp4", "bitblas", "gptq_bitblas"
|
||||
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
|
||||
"compressed-tensors", "experts_int8", "quark", "nvfp4", "bitblas",
|
||||
"gptq_bitblas"
|
||||
]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
@ -764,47 +764,13 @@ class ModelConfig:
|
||||
|
||||
if quant_cfg is not None:
|
||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||
quant_method = quant_method.replace("compressed_tensors",
|
||||
"compressed-tensors")
|
||||
quant_cfg["quant_method"] = quant_method
|
||||
|
||||
# Quantization methods which are overrides (i.e. they have a
|
||||
# `override_quantization_method` method) must be checked in order
|
||||
# of preference (this is particularly important for GPTQ).
|
||||
overrides = [
|
||||
"marlin",
|
||||
"bitblas",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"gptq_bitblas",
|
||||
"awq_marlin",
|
||||
"ipex",
|
||||
"moe_wna16",
|
||||
]
|
||||
quantization_methods = [
|
||||
q for q in supported_quantization if q not in overrides
|
||||
]
|
||||
# Any custom overrides will be in quantization_methods so we place
|
||||
# them at the start of the list so custom overrides have preference
|
||||
# over the built in ones.
|
||||
quantization_methods = quantization_methods + overrides
|
||||
|
||||
# Detect which checkpoint is it
|
||||
for name in quantization_methods:
|
||||
for name in QUANTIZATION_METHODS:
|
||||
method = get_quantization_config(name)
|
||||
quantization_override = method.override_quantization_method(
|
||||
quant_cfg, self.quantization)
|
||||
if quantization_override is not None:
|
||||
# Raise error if the override is not custom (custom would
|
||||
# be in QUANTIZATION_METHODS but not QuantizationMethods)
|
||||
# and hasn't been added to the overrides list.
|
||||
if (name in get_args(QuantizationMethods)
|
||||
and name not in overrides):
|
||||
raise ValueError(
|
||||
f"Quantization method {name} is an override but "
|
||||
"is has not been added to the `overrides` list "
|
||||
"above. This is necessary to ensure that the "
|
||||
"overrides are checked in order of preference.")
|
||||
if quantization_override:
|
||||
quant_method = quantization_override
|
||||
self.quantization = quantization_override
|
||||
break
|
||||
|
||||
@ -241,7 +241,7 @@ class MessageQueue:
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
remote_addr_ipv6 = True
|
||||
connect_ip = f"[{connect_ip}]"
|
||||
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
socket_addr = f"tcp://*:{remote_subscribe_port}"
|
||||
self.remote_socket.bind(socket_addr)
|
||||
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
else:
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
|
||||
TypeVar, Union, cast, get_args, get_origin)
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeIs, deprecated
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import version
|
||||
@ -48,29 +48,33 @@ TypeHint = Union[type[Any], object]
|
||||
TypeHintT = Union[type[T], object]
|
||||
|
||||
|
||||
def optional_type(
|
||||
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
|
||||
|
||||
def _optional_type(val: str) -> Optional[T]:
|
||||
if val == "" or val == "None":
|
||||
return None
|
||||
try:
|
||||
if return_type is json.loads and not re.match("^{.*}$", val):
|
||||
return cast(T, nullable_kvs(val))
|
||||
return return_type(val)
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Value {val} cannot be converted to {return_type}.") from e
|
||||
|
||||
return _optional_type
|
||||
def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
|
||||
if val == "" or val == "None":
|
||||
return None
|
||||
try:
|
||||
return return_type(val)
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Value {val} cannot be converted to {return_type}.") from e
|
||||
|
||||
|
||||
@deprecated(
|
||||
"Passing a JSON argument as a string containing comma separated key=value "
|
||||
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
|
||||
"string instead.")
|
||||
def nullable_kvs(val: str) -> dict[str, int]:
|
||||
"""Parses a string containing comma separate key [str] to value [int]
|
||||
def optional_str(val: str) -> Optional[str]:
|
||||
return optional_arg(val, str)
|
||||
|
||||
|
||||
def optional_int(val: str) -> Optional[int]:
|
||||
return optional_arg(val, int)
|
||||
|
||||
|
||||
def optional_float(val: str) -> Optional[float]:
|
||||
return optional_arg(val, float)
|
||||
|
||||
|
||||
def nullable_kvs(val: str) -> Optional[dict[str, int]]:
|
||||
"""NOTE: This function is deprecated, args should be passed as JSON
|
||||
strings instead.
|
||||
|
||||
Parses a string containing comma separate key [str] to value [int]
|
||||
pairs into a dictionary.
|
||||
|
||||
Args:
|
||||
@ -79,7 +83,10 @@ def nullable_kvs(val: str) -> dict[str, int]:
|
||||
Returns:
|
||||
Dictionary with parsed values.
|
||||
"""
|
||||
out_dict: dict[str, int] = {}
|
||||
if len(val) == 0:
|
||||
return None
|
||||
|
||||
out_dict: Dict[str, int] = {}
|
||||
for item in val.split(","):
|
||||
kv_parts = [part.lower().strip() for part in item.split("=")]
|
||||
if len(kv_parts) != 2:
|
||||
@ -101,103 +108,15 @@ def nullable_kvs(val: str) -> dict[str, int]:
|
||||
return out_dict
|
||||
|
||||
|
||||
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
|
||||
"""Check if the type hint is a specific type."""
|
||||
return type_hint is type or get_origin(type_hint) is type
|
||||
def optional_dict(val: str) -> Optional[dict[str, int]]:
|
||||
if re.match("^{.*}$", val):
|
||||
return optional_arg(val, json.loads)
|
||||
|
||||
|
||||
def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
|
||||
"""Check if the type hints contain a specific type."""
|
||||
return any(is_type(type_hint, type) for type_hint in type_hints)
|
||||
|
||||
|
||||
def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
|
||||
"""Get the specific type from the type hints."""
|
||||
return next((th for th in type_hints if is_type(th, type)), None)
|
||||
|
||||
|
||||
def is_not_builtin(type_hint: TypeHint) -> bool:
|
||||
"""Check if the class is not a built-in type."""
|
||||
return type_hint.__module__ != "builtins"
|
||||
|
||||
|
||||
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
cls_docs = get_attr_docs(cls)
|
||||
kwargs = {}
|
||||
for field in fields(cls):
|
||||
# Get the default value of the field
|
||||
default = field.default
|
||||
if field.default_factory is not MISSING:
|
||||
default = field.default_factory()
|
||||
|
||||
# Get the help text for the field
|
||||
name = field.name
|
||||
help = cls_docs[name]
|
||||
# Escape % for argparse
|
||||
help = help.replace("%", "%%")
|
||||
|
||||
# Initialise the kwargs dictionary for the field
|
||||
kwargs[name] = {"default": default, "help": help}
|
||||
|
||||
# Get the set of possible types for the field
|
||||
type_hints: set[TypeHint] = set()
|
||||
if get_origin(field.type) is Union:
|
||||
type_hints.update(get_args(field.type))
|
||||
else:
|
||||
type_hints.add(field.type)
|
||||
|
||||
# Set other kwargs based on the type hints
|
||||
if contains_type(type_hints, bool):
|
||||
# Creates --no-<name> and --<name> flags
|
||||
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
||||
elif contains_type(type_hints, Literal):
|
||||
# Creates choices from Literal arguments
|
||||
type_hint = get_type(type_hints, Literal)
|
||||
choices = sorted(get_args(type_hint))
|
||||
kwargs[name]["choices"] = choices
|
||||
choice_type = type(choices[0])
|
||||
assert all(type(c) is choice_type for c in choices), (
|
||||
"All choices must be of the same type. "
|
||||
f"Got {choices} with types {[type(c) for c in choices]}")
|
||||
kwargs[name]["type"] = choice_type
|
||||
elif contains_type(type_hints, tuple):
|
||||
type_hint = get_type(type_hints, tuple)
|
||||
types = get_args(type_hint)
|
||||
tuple_type = types[0]
|
||||
assert all(t is tuple_type for t in types if t is not Ellipsis), (
|
||||
"All non-Ellipsis tuple elements must be of the same "
|
||||
f"type. Got {types}.")
|
||||
kwargs[name]["type"] = tuple_type
|
||||
kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
|
||||
elif contains_type(type_hints, list):
|
||||
type_hint = get_type(type_hints, list)
|
||||
types = get_args(type_hint)
|
||||
assert len(types) == 1, (
|
||||
"List type must have exactly one type. Got "
|
||||
f"{type_hint} with types {types}")
|
||||
kwargs[name]["type"] = types[0]
|
||||
kwargs[name]["nargs"] = "+"
|
||||
elif contains_type(type_hints, int):
|
||||
kwargs[name]["type"] = int
|
||||
elif contains_type(type_hints, float):
|
||||
kwargs[name]["type"] = float
|
||||
elif contains_type(type_hints, dict):
|
||||
# Dict arguments will always be optional
|
||||
kwargs[name]["type"] = optional_type(json.loads)
|
||||
elif (contains_type(type_hints, str)
|
||||
or any(is_not_builtin(th) for th in type_hints)):
|
||||
kwargs[name]["type"] = str
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type {type_hints} for argument {name}.")
|
||||
|
||||
# If None is in type_hints, make the argument optional.
|
||||
# But not if it's a bool, argparse will handle this better.
|
||||
if type(None) in type_hints and not contains_type(type_hints, bool):
|
||||
kwargs[name]["type"] = optional_type(kwargs[name]["type"])
|
||||
if kwargs[name].get("choices"):
|
||||
kwargs[name]["choices"].append("None")
|
||||
return kwargs
|
||||
logger.warning(
|
||||
"Failed to parse JSON string. Attempting to parse as "
|
||||
"comma-separated key=value pairs. This will be deprecated in a "
|
||||
"future release.")
|
||||
return nullable_kvs(val)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -360,6 +279,100 @@ class EngineArgs:
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Shared CLI arguments for vLLM engine."""
|
||||
|
||||
def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
|
||||
"""Check if the class is a type in a union type."""
|
||||
is_union = get_origin(cls) is Union
|
||||
type_in_union = type in [get_origin(a) or a for a in get_args(cls)]
|
||||
return is_union and type_in_union
|
||||
|
||||
def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
|
||||
"""Get the type in a union type."""
|
||||
for arg in get_args(cls):
|
||||
if (get_origin(arg) or arg) is type:
|
||||
return arg
|
||||
raise ValueError(f"Type {type} not found in union type {cls}.")
|
||||
|
||||
def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
|
||||
"""Check if the class is an optional type."""
|
||||
return is_type_in_union(cls, type(None))
|
||||
|
||||
def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
|
||||
"""Check if the class can be of type."""
|
||||
return cls is type or get_origin(cls) is type or is_type_in_union(
|
||||
cls, type)
|
||||
|
||||
def is_custom_type(cls: TypeHint) -> bool:
|
||||
"""Check if the class is a custom type."""
|
||||
return cls.__module__ != "builtins"
|
||||
|
||||
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
cls_docs = get_attr_docs(cls)
|
||||
kwargs = {}
|
||||
for field in fields(cls):
|
||||
# Get the default value of the field
|
||||
default = field.default
|
||||
if field.default_factory is not MISSING:
|
||||
default = field.default_factory()
|
||||
|
||||
# Get the help text for the field
|
||||
name = field.name
|
||||
help = cls_docs[name]
|
||||
# Escape % for argparse
|
||||
help = help.replace("%", "%%")
|
||||
|
||||
# Initialise the kwargs dictionary for the field
|
||||
kwargs[name] = {"default": default, "help": help}
|
||||
|
||||
# Make note of if the field is optional and get the actual
|
||||
# type of the field if it is
|
||||
optional = is_optional(field.type)
|
||||
field_type = get_args(
|
||||
field.type)[0] if optional else field.type
|
||||
|
||||
# Set type, action and choices for the field depending on the
|
||||
# type of the field
|
||||
if can_be_type(field_type, bool):
|
||||
# Creates --no-<name> and --<name> flags
|
||||
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
||||
kwargs[name]["type"] = bool
|
||||
elif can_be_type(field_type, Literal):
|
||||
# Creates choices from Literal arguments
|
||||
if is_type_in_union(field_type, Literal):
|
||||
field_type = get_type_from_union(field_type, Literal)
|
||||
choices = get_args(field_type)
|
||||
kwargs[name]["choices"] = choices
|
||||
choice_type = type(choices[0])
|
||||
assert all(type(c) is choice_type for c in choices), (
|
||||
"All choices must be of the same type. "
|
||||
f"Got {choices} with types {[type(c) for c in choices]}"
|
||||
)
|
||||
kwargs[name]["type"] = choice_type
|
||||
elif can_be_type(field_type, tuple):
|
||||
if is_type_in_union(field_type, tuple):
|
||||
field_type = get_type_from_union(field_type, tuple)
|
||||
dtypes = get_args(field_type)
|
||||
dtype = dtypes[0]
|
||||
assert all(
|
||||
d is dtype for d in dtypes if d is not Ellipsis
|
||||
), ("All non-Ellipsis tuple elements must be of the same "
|
||||
f"type. Got {dtypes}.")
|
||||
kwargs[name]["type"] = dtype
|
||||
kwargs[name]["nargs"] = "+"
|
||||
elif can_be_type(field_type, int):
|
||||
kwargs[name]["type"] = optional_int if optional else int
|
||||
elif can_be_type(field_type, float):
|
||||
kwargs[name][
|
||||
"type"] = optional_float if optional else float
|
||||
elif can_be_type(field_type, dict):
|
||||
kwargs[name]["type"] = optional_dict
|
||||
elif (can_be_type(field_type, str)
|
||||
or is_custom_type(field_type)):
|
||||
kwargs[name]["type"] = optional_str if optional else str
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type {field.type} for argument {name}. ")
|
||||
return kwargs
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
@ -377,13 +390,13 @@ class EngineArgs:
|
||||
'which task to use.')
|
||||
parser.add_argument(
|
||||
'--tokenizer',
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=EngineArgs.tokenizer,
|
||||
help='Name or path of the huggingface tokenizer to use. '
|
||||
'If unspecified, model name or path will be used.')
|
||||
parser.add_argument(
|
||||
"--hf-config-path",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=EngineArgs.hf_config_path,
|
||||
help='Name or path of the huggingface config to use. '
|
||||
'If unspecified, model name or path will be used.')
|
||||
@ -395,21 +408,21 @@ class EngineArgs:
|
||||
'the input. The generated output will contain token ids.')
|
||||
parser.add_argument(
|
||||
'--revision',
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help='The specific model version to use. It can be a branch '
|
||||
'name, a tag name, or a commit id. If unspecified, will use '
|
||||
'the default version.')
|
||||
parser.add_argument(
|
||||
'--code-revision',
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help='The specific revision to use for the model code on '
|
||||
'Hugging Face Hub. It can be a branch name, a tag name, or a '
|
||||
'commit id. If unspecified, will use the default version.')
|
||||
parser.add_argument(
|
||||
'--tokenizer-revision',
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help='Revision of the huggingface tokenizer to use. '
|
||||
'It can be a branch name, a tag name, or a commit id. '
|
||||
@ -500,7 +513,7 @@ class EngineArgs:
|
||||
|
||||
parser.add_argument(
|
||||
'--logits-processor-pattern',
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help='Optional regex pattern specifying valid logits processor '
|
||||
'qualified names that can be passed with the `logits_processors` '
|
||||
@ -599,7 +612,7 @@ class EngineArgs:
|
||||
# Quantization settings.
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
choices=[*QUANTIZATION_METHODS, None],
|
||||
default=EngineArgs.quantization,
|
||||
help='Method used to quantize the weights. If '
|
||||
@ -908,7 +921,7 @@ class EngineArgs:
|
||||
'class without changing the existing functions.')
|
||||
parser.add_argument(
|
||||
"--generation-config",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default="auto",
|
||||
help="The folder path to the generation config. "
|
||||
"Defaults to 'auto', the generation config will be loaded from "
|
||||
|
||||
@ -11,7 +11,7 @@ import ssl
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union, get_args
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
validate_chat_template)
|
||||
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
|
||||
@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument("--host",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help="Host name.")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port number.")
|
||||
@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
default=["*"],
|
||||
help="Allowed headers.")
|
||||
parser.add_argument("--api-key",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help="If provided, the server will require this key "
|
||||
"to be presented in the header.")
|
||||
parser.add_argument(
|
||||
"--lora-modules",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=LoRAParserAction,
|
||||
@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"\"base_model_name\": \"id\"}``")
|
||||
parser.add_argument(
|
||||
"--prompt-adapters",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=PromptAdapterParserAction,
|
||||
help="Prompt adapter configurations in the format name=path. "
|
||||
"Multiple adapters can be specified.")
|
||||
parser.add_argument("--chat-template",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help="The file path to the chat template, "
|
||||
"or the template in single-line form "
|
||||
@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
'similar to OpenAI schema. '
|
||||
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
|
||||
parser.add_argument("--response-role",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"``request.add_generation_prompt=true``.")
|
||||
parser.add_argument("--ssl-keyfile",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help="The file path to the SSL key file.")
|
||||
parser.add_argument("--ssl-certfile",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help="The file path to the SSL cert file.")
|
||||
parser.add_argument("--ssl-ca-certs",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help="The CA certificates file.")
|
||||
parser.add_argument(
|
||||
@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root-path",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default=None,
|
||||
help="FastAPI root_path when app is behind a path based routing proxy."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--middleware",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
action="append",
|
||||
default=[],
|
||||
help="Additional ASGI middleware to apply to the app. "
|
||||
|
||||
@ -12,7 +12,7 @@ import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.logger import RequestLogger, logger
|
||||
# yapf: disable
|
||||
@ -61,7 +61,7 @@ def parse_args():
|
||||
"to the output URL.",
|
||||
)
|
||||
parser.add_argument("--response-role",
|
||||
type=optional_type(str),
|
||||
type=optional_str,
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=True`.")
|
||||
|
||||
@ -85,6 +85,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_MOE_PADDING: bool = True
|
||||
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
||||
VLLM_ENABLE_V1_ADVANCE_STEP: bool = False
|
||||
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
||||
VLLM_DISABLE_COMPILE_CACHE: bool = False
|
||||
Q_SCALE_CONSTANT: int = 200
|
||||
@ -600,6 +601,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
|
||||
"VLLM_DISABLE_COMPILE_CACHE":
|
||||
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
|
||||
"VLLM_ENABLE_V1_ADVANCE_STEP":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_ADVANCE_STEP", "0"))),
|
||||
|
||||
# If set, vllm will run in development mode, which will enable
|
||||
# some additional endpoints for developing and debugging,
|
||||
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@ -9,4 +9,5 @@ The example configurations provided are for the Mixtral model for TP2 on H100
|
||||
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
|
||||
N = 7168 and for TP4 we have N = 3584.
|
||||
|
||||
See `benchmark/kernels/benchmark_moe.py` on how to generate these config files.
|
||||
Please feel free to tune the configurations using scripts in `benchmarks/kernels/benchmark_moe.py`
|
||||
Some of the configurations files are copied from the SGLang repository. Thank you!
|
||||
|
||||
@ -113,9 +113,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
|
||||
# Padding the weight for better performance on ROCm
|
||||
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w13_weight.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w2_weight.data),
|
||||
requires_grad=False)
|
||||
# Lazy import to avoid importing triton.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
@ -124,8 +127,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight.data = shuffled_w13
|
||||
layer.w2_weight.data = shuffled_w2
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
if current_platform.is_cpu():
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
|
||||
|
||||
@ -929,15 +929,6 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
|
||||
shard_size = self._get_shard_size_mapping(loaded_shard_id)
|
||||
|
||||
# Note(simon): This is needed for Qwen3's fp8 quantization.
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
assert self.quant_method is not None
|
||||
assert hasattr(self.quant_method, "quant_config")
|
||||
weight_block_size = self.quant_method.quant_config.weight_block_size
|
||||
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||
shard_offset = (shard_offset + block_n - 1) // block_n
|
||||
shard_size = (shard_size + block_n - 1) // block_n
|
||||
|
||||
param.load_qkv_weight(loaded_weight=loaded_weight,
|
||||
num_heads=self.num_kv_head_replicas,
|
||||
shard_id=loaded_shard_id,
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Literal, Type, get_args
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
QuantizationMethods = Literal[
|
||||
QUANTIZATION_METHODS: List[str] = [
|
||||
"aqlm",
|
||||
"awq",
|
||||
"deepspeedfp",
|
||||
@ -15,6 +15,8 @@ QuantizationMethods = Literal[
|
||||
"fbgemm_fp8",
|
||||
"modelopt",
|
||||
"nvfp4",
|
||||
# The order of gptq methods is important for config.py iteration over
|
||||
# override_quantization_method(..)
|
||||
"marlin",
|
||||
"bitblas",
|
||||
"gguf",
|
||||
@ -34,7 +36,6 @@ QuantizationMethods = Literal[
|
||||
"moe_wna16",
|
||||
"torchao",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
# The customized quantization methods which will be added to this dict.
|
||||
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
|
||||
@ -110,7 +111,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
from .torchao import TorchAOConfig
|
||||
from .tpu_int8 import Int8TpuConfig
|
||||
|
||||
method_to_config: dict[str, Type[QuantizationConfig]] = {
|
||||
method_to_config: Dict[str, Type[QuantizationConfig]] = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
@ -119,6 +120,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"nvfp4": ModelOptNvFp4Config,
|
||||
# The order of gptq methods is important for config.py iteration over
|
||||
# override_quantization_method(..)
|
||||
"marlin": MarlinConfig,
|
||||
"bitblas": BitBLASConfig,
|
||||
"gguf": GGUFConfig,
|
||||
@ -147,7 +150,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"QuantizationMethods",
|
||||
"get_quantization_config",
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
@ -72,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return 70
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "compressed-tensors"
|
||||
return "compressed_tensors"
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
|
||||
@ -156,7 +156,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: List) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_mi250_mi300
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
|
||||
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
|
||||
current_platform.get_cu_count())
|
||||
|
||||
@ -130,8 +130,8 @@ class RocmPlatform(Platform):
|
||||
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
|
||||
|
||||
supported_quantization: list[str] = [
|
||||
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
|
||||
"quark", "ptpc_fp8"
|
||||
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
|
||||
"fbgemm_fp8", "gguf", "quark", "ptpc_fp8"
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -30,7 +30,9 @@ class TpuPlatform(Platform):
|
||||
ray_device_key: str = "TPU"
|
||||
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
|
||||
|
||||
supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
|
||||
supported_quantization: list[str] = [
|
||||
"tpu_int8", "compressed-tensors", "compressed_tensors"
|
||||
]
|
||||
|
||||
additional_env_vars: list[str] = [
|
||||
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
|
||||
|
||||
@ -10,11 +10,9 @@ from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
@ -278,23 +276,13 @@ def make_local_attention_virtual_batches(
|
||||
block_table_local
|
||||
|
||||
|
||||
def _get_sliding_window_configs(
|
||||
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
|
||||
"""Get the set of all sliding window configs used in the model."""
|
||||
sliding_window_configs: set[Optional[tuple[int, int]]] = set()
|
||||
layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for layer in layers.values():
|
||||
assert isinstance(layer.impl, FlashAttentionImpl)
|
||||
sliding_window_configs.add(layer.impl.sliding_window)
|
||||
return sliding_window_configs
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder:
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner"):
|
||||
model_config = runner.model_config
|
||||
|
||||
self.runner = runner
|
||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
@ -302,11 +290,6 @@ class FlashAttentionMetadataBuilder:
|
||||
self.headdim = model_config.get_head_size()
|
||||
self.page_size = self.runner.block_size
|
||||
|
||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||
# Sliding window size to be used with the AOT scheduler will be
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
@ -324,22 +307,6 @@ class FlashAttentionMetadataBuilder:
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
|
||||
if self.aot_sliding_window is None:
|
||||
self.aot_sliding_window = (-1, -1)
|
||||
# For the AOT scheduler we need the sliding window value to be
|
||||
# constant for all layers to. We have to populate this on the first
|
||||
# build() call so the layers are constructed (cannot populate)
|
||||
# in __init__.
|
||||
if self.aot_schedule:
|
||||
sliding_window_configs = _get_sliding_window_configs(
|
||||
self.runner.vllm_config)
|
||||
if len(sliding_window_configs) == 1:
|
||||
sliding_window_config = sliding_window_configs.pop()
|
||||
if sliding_window_config is not None:
|
||||
self.aot_sliding_window = sliding_window_config
|
||||
elif len(sliding_window_configs) > 1:
|
||||
self.aot_schedule = False
|
||||
|
||||
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
|
||||
max_seq_len, causal):
|
||||
if self.aot_schedule:
|
||||
@ -354,7 +321,6 @@ class FlashAttentionMetadataBuilder:
|
||||
page_size=self.page_size,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
causal=causal,
|
||||
window_size=self.aot_sliding_window,
|
||||
)
|
||||
return None
|
||||
|
||||
@ -406,7 +372,7 @@ class FlashAttentionMetadataBuilder:
|
||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||
self.runner.device)
|
||||
prefix_scheduler_metadata = schedule(
|
||||
batch_size=1,
|
||||
batch_size=num_reqs,
|
||||
cu_query_lens=cu_prefix_query_lens,
|
||||
max_query_len=num_actual_tokens,
|
||||
seqlens=prefix_kv_lens,
|
||||
|
||||
@ -739,10 +739,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Return the cached request data to the queue so they can be reused.
|
||||
for req_data in scheduler_output.scheduled_cached_reqs:
|
||||
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
|
||||
# to _cached_reqs_data will cause a memory leak.
|
||||
if req_data.req_id not in self.finished_req_ids:
|
||||
self._cached_reqs_data[req_data.req_id].append(req_data)
|
||||
self._cached_reqs_data[req_data.req_id].append(req_data)
|
||||
|
||||
self.running = new_running
|
||||
engine_core_outputs = EngineCoreOutputs(
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -36,6 +37,9 @@ class BlockTable:
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
self.prev_num_reqs = 0
|
||||
self.is_updated = True
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
block_ids: list[int],
|
||||
@ -48,16 +52,22 @@ class BlockTable:
|
||||
self.num_blocks_per_row[row_idx] += num_blocks
|
||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||
|
||||
self.is_updated = True
|
||||
|
||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
self.append_row(block_ids, row_idx)
|
||||
|
||||
self.is_updated = True
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
|
||||
src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
self.is_updated = True
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks_src = self.num_blocks_per_row[src]
|
||||
num_blocks_tgt = self.num_blocks_per_row[tgt]
|
||||
@ -66,14 +76,28 @@ class BlockTable:
|
||||
|
||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
||||
|
||||
self.is_updated = True
|
||||
|
||||
def commit(self, num_reqs: int) -> None:
|
||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
|
||||
# Incremental copy
|
||||
if self.prev_num_reqs != num_reqs or self.is_updated:
|
||||
self.block_table[:num_reqs].copy_(
|
||||
self.block_table_cpu[:num_reqs], non_blocking=True)
|
||||
|
||||
self.prev_num_reqs = num_reqs
|
||||
self.is_updated = False
|
||||
else:
|
||||
# Always copy
|
||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.fill_(0)
|
||||
self.block_table_cpu.fill_(0)
|
||||
|
||||
self.is_updated = True
|
||||
|
||||
def get_device_tensor(self) -> torch.Tensor:
|
||||
"""Ruturns the device tensor of the block table."""
|
||||
return self.block_table
|
||||
|
||||
@ -10,6 +10,7 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import (CompilationLevel, VllmConfig,
|
||||
@ -142,6 +143,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
weakref.proxy(self))
|
||||
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
|
||||
|
||||
if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
|
||||
logger.info("Advance_step is enabled")
|
||||
if self.cascade_attn_enabled:
|
||||
logger.warning(
|
||||
"Disabling cascade attn (since advance_step is on)")
|
||||
self.cascade_attn_enabled = False
|
||||
else:
|
||||
logger.info("Advance_step is disabled")
|
||||
|
||||
# Multi-modal data support
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.uses_mrope = model_config.uses_mrope
|
||||
@ -271,16 +281,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
self.slot_mapping_gpu = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
|
||||
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
||||
self.query_start_loc_gpu = torch.zeros(self.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||
self.seq_lens_gpu = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
# Cached
|
||||
self.prev_num_reqs = 0
|
||||
self.req_indices_gpu = torch.arange(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
self.req_indices_block_table_offsets_gpu = (
|
||||
self.req_indices_gpu * self.max_num_blocks_per_req)
|
||||
|
||||
self.num_scheduled_tokens_gpu = torch.ones(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
self.cu_num_tokens_gpu = torch.cumsum(self.num_scheduled_tokens_gpu, 0)
|
||||
|
||||
self.query_start_loc_gpu[0] = 0
|
||||
self.query_start_loc_gpu[1:self.max_num_reqs +
|
||||
1] = self.cu_num_tokens_gpu
|
||||
|
||||
self.logits_indices_gpu = self.query_start_loc_gpu[1:] - 1
|
||||
|
||||
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
self.prev_attn_metadata = None
|
||||
self.is_first_advance_decode = True
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""Update the cached states and the persistent batch with the scheduler
|
||||
@ -485,6 +530,119 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if batch_changed or batch_reordered:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
def _advance_decode_step(
|
||||
self,
|
||||
scheduler_output,
|
||||
num_scheduled_tokens,
|
||||
):
|
||||
# print(" -- inside advance_decode_step")
|
||||
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens == num_reqs
|
||||
|
||||
# TODO: Add if needed
|
||||
# Get request indices.
|
||||
# E.g., num_reqs == 3 -> [0, 1, 2]
|
||||
# req_indices_gpu = self.req_indices_gpu[:num_reqs]
|
||||
# Get cu_sums
|
||||
# cu_num_tokens = self.cu_num_tokens_gpu[:num_reqs]
|
||||
|
||||
# Increment positions
|
||||
positions_gpu = self.positions[:total_num_scheduled_tokens]
|
||||
positions_gpu[:total_num_scheduled_tokens] += 1
|
||||
|
||||
# TODO: Verify MROPE is ok here
|
||||
# Calculate M-RoPE positions.
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
self._calc_mrope_positions(scheduler_output)
|
||||
|
||||
# Set next tokens
|
||||
# (prev iteration tokens are cached in prev_sampled_token_ids tensor)
|
||||
assert self.prev_sampled_token_ids is not None
|
||||
self.input_ids[:total_num_scheduled_tokens] = \
|
||||
self.prev_sampled_token_ids[:,0]
|
||||
|
||||
# Calculate the slot mapping
|
||||
block_table_indices_gpu = (
|
||||
self.req_indices_block_table_offsets_gpu[:num_reqs] +
|
||||
positions_gpu // self.block_size)
|
||||
block_table_gpu = self.input_batch.block_table.get_device_tensor()
|
||||
# Note: The block table tensor is async copied from CPU to GPU
|
||||
# (inside the .commit() call) if was previously modified
|
||||
block_numbers_gpu = block_table_gpu.flatten()[block_table_indices_gpu]
|
||||
|
||||
block_offsets_gpu = positions_gpu % self.block_size
|
||||
|
||||
slot_mapping_gpu = self.slot_mapping_gpu[:total_num_scheduled_tokens]
|
||||
slot_mapping_gpu[:] = (block_numbers_gpu * self.block_size +
|
||||
block_offsets_gpu)
|
||||
|
||||
# Prepare the attention metadata.
|
||||
|
||||
# query_start_loc is always the same for all decode iterations
|
||||
query_start_loc_gpu = self.query_start_loc_gpu[:num_reqs + 1]
|
||||
|
||||
if self.uses_mrope:
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
||||
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
|
||||
# TODO: Add cascade attn support
|
||||
# Verify cascade attention is disabled
|
||||
assert not self.cascade_attn_enabled
|
||||
|
||||
# TODO: Add support for other attn backends
|
||||
assert self.prev_attn_metadata is not None
|
||||
assert isinstance(self.prev_attn_metadata, FlashAttentionMetadata)
|
||||
|
||||
attn_metadata = self.prev_attn_metadata
|
||||
attn_metadata.max_seq_len += 1
|
||||
attn_metadata.query_start_loc = query_start_loc_gpu
|
||||
attn_metadata.seq_lens += 1
|
||||
attn_metadata.slot_mapping = slot_mapping_gpu
|
||||
|
||||
# print("attn_metadata.seq_lens: shape = {} data = {}".format(
|
||||
# attn_metadata.seq_lens.shape, attn_metadata.seq_lens))
|
||||
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if not use_spec_decode:
|
||||
# NOTE(woosuk): Due to chunked prefills, the batch may contain
|
||||
# partial requests. While we should not sample any token
|
||||
# from these partial requests, we do so for simplicity.
|
||||
# We will ignore the sampled tokens from the partial requests.
|
||||
# TODO: Support prompt logprobs.
|
||||
logits_indices = self.logits_indices_gpu[:num_reqs]
|
||||
spec_decode_metadata = None
|
||||
else:
|
||||
# TODO: Check if spec_decode can be enabled here
|
||||
raise Exception("advance_step has no support for spec_decode yet")
|
||||
# # Get the number of draft tokens for each request.
|
||||
# # Iterate over the dictionary rather than all requests since
|
||||
# # not all requests have draft tokens.
|
||||
# num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||||
# for req_id, draft_token_ids in (
|
||||
# scheduler_output.scheduled_spec_decode_tokens.items()):
|
||||
# req_idx = self.input_batch.req_id_to_index[req_id]
|
||||
# num_draft_tokens[req_idx] = len(draft_token_ids)
|
||||
|
||||
# spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
# num_draft_tokens, cu_num_tokens)
|
||||
# logits_indices = spec_decode_metadata.logits_indices
|
||||
|
||||
# Hot-Swap lora model
|
||||
if self.lora_config:
|
||||
# TODO: Check if this works
|
||||
raise Exception("advance_step has no LORA support yet")
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
|
||||
return attn_metadata, logits_indices, spec_decode_metadata
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@ -505,6 +663,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
|
||||
# Determine if advance step can be used
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
|
||||
is_flash_attn = self.prev_attn_metadata is not None and isinstance(
|
||||
self.prev_attn_metadata, FlashAttentionMetadata)
|
||||
|
||||
is_advance_decode = (envs.VLLM_ENABLE_V1_ADVANCE_STEP
|
||||
and self.prev_num_reqs == num_reqs
|
||||
and max_num_scheduled_tokens == 1
|
||||
and not use_spec_decode
|
||||
and not self.cascade_attn_enabled
|
||||
and is_flash_attn)
|
||||
|
||||
if is_advance_decode:
|
||||
if self.is_first_advance_decode:
|
||||
# The first time advance_step can be used,
|
||||
# we run the usual prepare, so that positions tensor
|
||||
# is initialized
|
||||
self.is_first_advance_decode = False
|
||||
else:
|
||||
# This is the fast-path advance_step
|
||||
# (all tensors are on the GPU and are updated on the GPU)
|
||||
(attn_metadata, logits_indices,
|
||||
spec_decode_metadata) = self._advance_decode_step(
|
||||
scheduler_output, num_scheduled_tokens)
|
||||
return attn_metadata, logits_indices, spec_decode_metadata
|
||||
else:
|
||||
self.is_first_advance_decode = True
|
||||
|
||||
self.prev_num_reqs = num_reqs
|
||||
|
||||
# Get request indices.
|
||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||
@ -523,6 +713,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
||||
|
||||
# Get positions.
|
||||
|
||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
arange,
|
||||
@ -599,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=common_prefix_len,
|
||||
)
|
||||
self.prev_attn_metadata = attn_metadata
|
||||
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
@ -1177,6 +1369,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Get the valid generated tokens.
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
self.prev_sampled_token_ids = sampled_token_ids
|
||||
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
# No spec decode tokens.
|
||||
|
||||
Reference in New Issue
Block a user