[CLI] Improve CLI arg parsing for -O/--compilation-config (#20156)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-06-30 21:03:13 -04:00
committed by GitHub
parent ded1fb635b
commit 6d42ce8315
5 changed files with 124 additions and 40 deletions

View File

@@ -239,32 +239,40 @@ def test_compilation_config():
assert args.compilation_config == CompilationConfig()
# set to O3
args = parser.parse_args(["-O3"])
assert args.compilation_config.level == 3
args = parser.parse_args(["-O0"])
assert args.compilation_config.level == 0
# set to O 3 (space)
args = parser.parse_args(["-O", "3"])
assert args.compilation_config.level == 3
args = parser.parse_args(["-O", "1"])
assert args.compilation_config.level == 1
# set to O 3 (equals)
args = parser.parse_args(["-O=3"])
args = parser.parse_args(["-O=2"])
assert args.compilation_config.level == 2
# set to O.level 3
args = parser.parse_args(["-O.level", "3"])
assert args.compilation_config.level == 3
# set to string form of a dict
args = parser.parse_args([
"--compilation-config",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
"-O",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": false}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and not args.compilation_config.use_inductor)
# set to string form of a dict
args = parser.parse_args([
"--compilation-config="
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": true}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and args.compilation_config.use_inductor)
def test_prefix_cache_default():

View File

@@ -5,6 +5,7 @@
import asyncio
import hashlib
import json
import logging
import pickle
import socket
from collections.abc import AsyncIterator
@@ -142,6 +143,7 @@ def parser():
parser.add_argument('--batch-size', type=int)
parser.add_argument('--enable-feature', action='store_true')
parser.add_argument('--hf-overrides', type=json.loads)
parser.add_argument('-O', '--compilation-config', type=json.loads)
return parser
@@ -265,6 +267,11 @@ def test_dict_args(parser):
"val2",
"--hf-overrides.key2.key4",
"val3",
# Test compile config and compilation level
"-O.use_inductor=true",
"-O.backend",
"custom",
"-O1",
# Test = sign
"--hf-overrides.key5=val4",
# Test underscore to dash conversion
@@ -281,6 +288,13 @@ def test_dict_args(parser):
"true",
"--hf_overrides.key12.key13",
"null",
# Test '-' and '.' in value
"--hf_overrides.key14.key15",
"-minus.and.dot",
# Test array values
"-O.custom_ops+",
"-quant_fp8",
"-O.custom_ops+=+silu_mul,-rms_norm",
]
parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something"
@@ -301,7 +315,40 @@ def test_dict_args(parser):
"key12": {
"key13": None,
},
"key14": {
"key15": "-minus.and.dot",
}
}
assert parsed_args.compilation_config == {
"level": 1,
"use_inductor": True,
"backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
}
def test_duplicate_dict_args(caplog_vllm, parser):
args = [
"--model-name=something.something",
"--hf-overrides.key1",
"val1",
"--hf-overrides.key1",
"val2",
"-O1",
"-O.level",
"2",
"-O3",
]
parsed_args = parser.parse_args(args)
# Should be the last value
assert parsed_args.hf_overrides == {"key1": "val2"}
assert parsed_args.compilation_config == {"level": 3}
assert len(caplog_vllm.records) == 1
assert "duplicate" in caplog_vllm.text
assert "--hf-overrides.key1" in caplog_vllm.text
assert "-O.level" in caplog_vllm.text
# yapf: enable

View File

@@ -4140,9 +4140,9 @@ class CompilationConfig:
@classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
"""Parse the CLI value for the compilation config.
-O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
"""
return TypeAdapter(CompilationConfig).validate_json(cli_value)
def __post_init__(self) -> None:
@@ -4303,17 +4303,16 @@ class VllmConfig:
"""Quantization configuration."""
compilation_config: CompilationConfig = field(
default_factory=CompilationConfig)
"""`torch.compile` configuration for the model.
"""`torch.compile` and cudagraph capture configuration for the model.
When it is a number (0, 1, 2, 3), it will be interpreted as the
optimization level.
As a shorthand, `-O<n>` can be used to directly specify the compilation
level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
Currently, -O <n> and -O=<n> are supported as well but this will likely be
removed in favor of clearer -O<n> syntax in the future.
NOTE: level 0 is the default level without any optimization. level 1 and 2
are for internal testing only. level 3 is the recommended level for
production.
Following the convention of traditional compilers, using `-O` without space
is also supported. `-O3` is equivalent to `-O 3`.
production, also default in V1.
You can specify the full compilation config like so:
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`

View File

@@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
passed individually. For example, the following sets of arguments are
equivalent:\n\n
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
Additionally, list elements can be passed individually using '+':
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
if dataclass_cls is not None:
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:

View File

@@ -89,15 +89,15 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
STR_NOT_IMPL_ENC_DEC_SWA = \
"Sliding window attention for encoder/decoder models " + \
"is not currently supported."
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
"Prefix caching for encoder/decoder models " + \
"is not currently supported."
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
"Chunked prefill for encoder/decoder models " + \
"is not currently supported."
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
"Models with logits_soft_cap "
@@ -752,7 +752,7 @@ def _generate_random_fp8(
# to generate random data for fp8 data.
# For example, s.11111.00 in fp8e5m2 format represents Inf.
# | E4M3 | E5M2
#-----|-------------|-------------------
# -----|-------------|-------------------
# Inf | N/A | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11}
from vllm import _custom_ops as ops
@@ -840,7 +840,6 @@ def create_kv_caches_with_random(
seed: Optional[int] = None,
device: Optional[str] = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16:
raise ValueError(
f"Does not support key cache of type fp8 with head_size {head_size}"
@@ -1205,7 +1204,6 @@ def deprecate_args(
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None,
) -> Callable[[F], F]:
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
@@ -1355,7 +1353,7 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
return weak_bound
#From: https://stackoverflow.com/a/4104188/2749989
# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f: Callable[P, None]) -> Callable[P, None]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
@@ -1474,7 +1472,7 @@ class FlexibleArgumentParser(ArgumentParser):
# Convert underscores to dashes and vice versa in argument names
processed_args = list[str]()
for arg in args:
for i, arg in enumerate(args):
if arg.startswith('--'):
if '=' in arg:
key, value = arg.split('=', 1)
@@ -1483,10 +1481,17 @@ class FlexibleArgumentParser(ArgumentParser):
else:
key = pattern.sub(repl, arg, count=1)
processed_args.append(key)
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
# allow -O flag to be used without space, e.g. -O3
processed_args.append('-O')
processed_args.append(arg[2:])
elif arg.startswith('-O') and arg != '-O' and arg[2] != '.':
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<level> here
level = arg[3:] if arg[2] == '=' else arg[2:]
processed_args.append(f'-O.level={level}')
elif arg == '-O' and i + 1 < len(args) and args[i + 1] in {
"0", "1", "2", "3"
}:
# Convert -O <n> to -O.level <n>
processed_args.append('-O.level')
else:
processed_args.append(arg)
@@ -1504,27 +1509,44 @@ class FlexibleArgumentParser(ArgumentParser):
def recursive_dict_update(
original: dict[str, Any],
update: dict[str, Any],
):
"""Recursively updates a dictionary with another dictionary."""
) -> set[str]:
"""Recursively updates a dictionary with another dictionary.
Returns a set of duplicate keys that were overwritten.
"""
duplicates = set[str]()
for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict):
recursive_dict_update(original[k], v)
nested_duplicates = recursive_dict_update(original[k], v)
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
elif isinstance(v, list) and isinstance(original.get(k), list):
original[k] += v
else:
if k in original:
duplicates.add(k)
original[k] = v
return duplicates
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
for i, processed_arg in enumerate(processed_args):
if processed_arg.startswith("--") and "." in processed_arg:
if i in delete: # skip if value from previous arg
continue
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg:
# False positive, . was only in the value
# False positive, '.' was only in the value
continue
else:
value_str = processed_args[i + 1]
delete.add(i + 1)
if processed_arg.endswith("+"):
processed_arg = processed_arg[:-1]
value_str = json.dumps(list(value_str.split(",")))
key, *keys = processed_arg.split(".")
try:
value = json.loads(value_str)
@@ -1533,12 +1555,17 @@ class FlexibleArgumentParser(ArgumentParser):
# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
recursive_dict_update(dict_args[key], arg_dict)
arg_duplicates = recursive_dict_update(dict_args[key],
arg_dict)
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
delete.add(i)
# Filter out the dict args we set to None
processed_args = [
a for i, a in enumerate(processed_args) if i not in delete
]
if duplicates:
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
# Add the dict args back as if they were originally passed as JSON
for dict_arg, dict_value in dict_args.items():
processed_args.append(dict_arg)
@@ -2405,7 +2432,7 @@ def memory_profiling(
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
""" # noqa
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()