[CLI] Improve CLI arg parsing for -O/--compilation-config (#20156)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@ -239,32 +239,40 @@ def test_compilation_config():
|
||||
assert args.compilation_config == CompilationConfig()
|
||||
|
||||
# set to O3
|
||||
args = parser.parse_args(["-O3"])
|
||||
assert args.compilation_config.level == 3
|
||||
args = parser.parse_args(["-O0"])
|
||||
assert args.compilation_config.level == 0
|
||||
|
||||
# set to O 3 (space)
|
||||
args = parser.parse_args(["-O", "3"])
|
||||
assert args.compilation_config.level == 3
|
||||
args = parser.parse_args(["-O", "1"])
|
||||
assert args.compilation_config.level == 1
|
||||
|
||||
# set to O 3 (equals)
|
||||
args = parser.parse_args(["-O=3"])
|
||||
args = parser.parse_args(["-O=2"])
|
||||
assert args.compilation_config.level == 2
|
||||
|
||||
# set to O.level 3
|
||||
args = parser.parse_args(["-O.level", "3"])
|
||||
assert args.compilation_config.level == 3
|
||||
|
||||
# set to string form of a dict
|
||||
args = parser.parse_args([
|
||||
"--compilation-config",
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
|
||||
"-O",
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||
'"use_inductor": false}',
|
||||
])
|
||||
assert (args.compilation_config.level == 3 and
|
||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||
and not args.compilation_config.use_inductor)
|
||||
|
||||
# set to string form of a dict
|
||||
args = parser.parse_args([
|
||||
"--compilation-config="
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||
'"use_inductor": true}',
|
||||
])
|
||||
assert (args.compilation_config.level == 3 and
|
||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||
and args.compilation_config.use_inductor)
|
||||
|
||||
|
||||
def test_prefix_cache_default():
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import socket
|
||||
from collections.abc import AsyncIterator
|
||||
@ -142,6 +143,7 @@ def parser():
|
||||
parser.add_argument('--batch-size', type=int)
|
||||
parser.add_argument('--enable-feature', action='store_true')
|
||||
parser.add_argument('--hf-overrides', type=json.loads)
|
||||
parser.add_argument('-O', '--compilation-config', type=json.loads)
|
||||
return parser
|
||||
|
||||
|
||||
@ -265,6 +267,11 @@ def test_dict_args(parser):
|
||||
"val2",
|
||||
"--hf-overrides.key2.key4",
|
||||
"val3",
|
||||
# Test compile config and compilation level
|
||||
"-O.use_inductor=true",
|
||||
"-O.backend",
|
||||
"custom",
|
||||
"-O1",
|
||||
# Test = sign
|
||||
"--hf-overrides.key5=val4",
|
||||
# Test underscore to dash conversion
|
||||
@ -281,6 +288,13 @@ def test_dict_args(parser):
|
||||
"true",
|
||||
"--hf_overrides.key12.key13",
|
||||
"null",
|
||||
# Test '-' and '.' in value
|
||||
"--hf_overrides.key14.key15",
|
||||
"-minus.and.dot",
|
||||
# Test array values
|
||||
"-O.custom_ops+",
|
||||
"-quant_fp8",
|
||||
"-O.custom_ops+=+silu_mul,-rms_norm",
|
||||
]
|
||||
parsed_args = parser.parse_args(args)
|
||||
assert parsed_args.model_name == "something.something"
|
||||
@ -301,7 +315,40 @@ def test_dict_args(parser):
|
||||
"key12": {
|
||||
"key13": None,
|
||||
},
|
||||
"key14": {
|
||||
"key15": "-minus.and.dot",
|
||||
}
|
||||
}
|
||||
assert parsed_args.compilation_config == {
|
||||
"level": 1,
|
||||
"use_inductor": True,
|
||||
"backend": "custom",
|
||||
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
|
||||
}
|
||||
|
||||
|
||||
def test_duplicate_dict_args(caplog_vllm, parser):
|
||||
args = [
|
||||
"--model-name=something.something",
|
||||
"--hf-overrides.key1",
|
||||
"val1",
|
||||
"--hf-overrides.key1",
|
||||
"val2",
|
||||
"-O1",
|
||||
"-O.level",
|
||||
"2",
|
||||
"-O3",
|
||||
]
|
||||
|
||||
parsed_args = parser.parse_args(args)
|
||||
# Should be the last value
|
||||
assert parsed_args.hf_overrides == {"key1": "val2"}
|
||||
assert parsed_args.compilation_config == {"level": 3}
|
||||
|
||||
assert len(caplog_vllm.records) == 1
|
||||
assert "duplicate" in caplog_vllm.text
|
||||
assert "--hf-overrides.key1" in caplog_vllm.text
|
||||
assert "-O.level" in caplog_vllm.text
|
||||
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@ -4140,9 +4140,9 @@ class CompilationConfig:
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str) -> "CompilationConfig":
|
||||
"""Parse the CLI value for the compilation config."""
|
||||
if cli_value in ["0", "1", "2", "3"]:
|
||||
return cls(level=int(cli_value))
|
||||
"""Parse the CLI value for the compilation config.
|
||||
-O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
|
||||
"""
|
||||
return TypeAdapter(CompilationConfig).validate_json(cli_value)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@ -4303,17 +4303,16 @@ class VllmConfig:
|
||||
"""Quantization configuration."""
|
||||
compilation_config: CompilationConfig = field(
|
||||
default_factory=CompilationConfig)
|
||||
"""`torch.compile` configuration for the model.
|
||||
"""`torch.compile` and cudagraph capture configuration for the model.
|
||||
|
||||
When it is a number (0, 1, 2, 3), it will be interpreted as the
|
||||
optimization level.
|
||||
As a shorthand, `-O<n>` can be used to directly specify the compilation
|
||||
level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
|
||||
Currently, -O <n> and -O=<n> are supported as well but this will likely be
|
||||
removed in favor of clearer -O<n> syntax in the future.
|
||||
|
||||
NOTE: level 0 is the default level without any optimization. level 1 and 2
|
||||
are for internal testing only. level 3 is the recommended level for
|
||||
production.
|
||||
|
||||
Following the convention of traditional compilers, using `-O` without space
|
||||
is also supported. `-O3` is equivalent to `-O 3`.
|
||||
production, also default in V1.
|
||||
|
||||
You can specify the full compilation config like so:
|
||||
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
|
||||
|
||||
@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
passed individually. For example, the following sets of arguments are
|
||||
equivalent:\n\n
|
||||
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
|
||||
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
|
||||
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
|
||||
Additionally, list elements can be passed individually using '+':
|
||||
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
|
||||
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
|
||||
if dataclass_cls is not None:
|
||||
|
||||
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
|
||||
|
||||
@ -89,15 +89,15 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_SWA = \
|
||||
"Sliding window attention for encoder/decoder models " + \
|
||||
"is not currently supported."
|
||||
"is not currently supported."
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
|
||||
"Prefix caching for encoder/decoder models " + \
|
||||
"is not currently supported."
|
||||
"is not currently supported."
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
|
||||
"Chunked prefill for encoder/decoder models " + \
|
||||
"is not currently supported."
|
||||
"is not currently supported."
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
|
||||
"Models with logits_soft_cap "
|
||||
@ -752,7 +752,7 @@ def _generate_random_fp8(
|
||||
# to generate random data for fp8 data.
|
||||
# For example, s.11111.00 in fp8e5m2 format represents Inf.
|
||||
# | E4M3 | E5M2
|
||||
#-----|-------------|-------------------
|
||||
# -----|-------------|-------------------
|
||||
# Inf | N/A | s.11111.00
|
||||
# NaN | s.1111.111 | s.11111.{01,10,11}
|
||||
from vllm import _custom_ops as ops
|
||||
@ -840,7 +840,6 @@ def create_kv_caches_with_random(
|
||||
seed: Optional[int] = None,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
|
||||
if cache_dtype == "fp8" and head_size % 16:
|
||||
raise ValueError(
|
||||
f"Does not support key cache of type fp8 with head_size {head_size}"
|
||||
@ -1205,7 +1204,6 @@ def deprecate_args(
|
||||
is_deprecated: Union[bool, Callable[[], bool]] = True,
|
||||
additional_message: Optional[str] = None,
|
||||
) -> Callable[[F], F]:
|
||||
|
||||
if not callable(is_deprecated):
|
||||
is_deprecated = partial(identity, is_deprecated)
|
||||
|
||||
@ -1355,7 +1353,7 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
|
||||
return weak_bound
|
||||
|
||||
|
||||
#From: https://stackoverflow.com/a/4104188/2749989
|
||||
# From: https://stackoverflow.com/a/4104188/2749989
|
||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
||||
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
||||
@ -1474,7 +1472,7 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
|
||||
# Convert underscores to dashes and vice versa in argument names
|
||||
processed_args = list[str]()
|
||||
for arg in args:
|
||||
for i, arg in enumerate(args):
|
||||
if arg.startswith('--'):
|
||||
if '=' in arg:
|
||||
key, value = arg.split('=', 1)
|
||||
@ -1483,10 +1481,17 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
else:
|
||||
key = pattern.sub(repl, arg, count=1)
|
||||
processed_args.append(key)
|
||||
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
|
||||
# allow -O flag to be used without space, e.g. -O3
|
||||
processed_args.append('-O')
|
||||
processed_args.append(arg[2:])
|
||||
elif arg.startswith('-O') and arg != '-O' and arg[2] != '.':
|
||||
# allow -O flag to be used without space, e.g. -O3 or -Odecode
|
||||
# -O.<...> handled later
|
||||
# also handle -O=<level> here
|
||||
level = arg[3:] if arg[2] == '=' else arg[2:]
|
||||
processed_args.append(f'-O.level={level}')
|
||||
elif arg == '-O' and i + 1 < len(args) and args[i + 1] in {
|
||||
"0", "1", "2", "3"
|
||||
}:
|
||||
# Convert -O <n> to -O.level <n>
|
||||
processed_args.append('-O.level')
|
||||
else:
|
||||
processed_args.append(arg)
|
||||
|
||||
@ -1504,27 +1509,44 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
def recursive_dict_update(
|
||||
original: dict[str, Any],
|
||||
update: dict[str, Any],
|
||||
):
|
||||
"""Recursively updates a dictionary with another dictionary."""
|
||||
) -> set[str]:
|
||||
"""Recursively updates a dictionary with another dictionary.
|
||||
Returns a set of duplicate keys that were overwritten.
|
||||
"""
|
||||
duplicates = set[str]()
|
||||
for k, v in update.items():
|
||||
if isinstance(v, dict) and isinstance(original.get(k), dict):
|
||||
recursive_dict_update(original[k], v)
|
||||
nested_duplicates = recursive_dict_update(original[k], v)
|
||||
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
|
||||
elif isinstance(v, list) and isinstance(original.get(k), list):
|
||||
original[k] += v
|
||||
else:
|
||||
if k in original:
|
||||
duplicates.add(k)
|
||||
original[k] = v
|
||||
return duplicates
|
||||
|
||||
delete = set[int]()
|
||||
dict_args = defaultdict[str, dict[str, Any]](dict)
|
||||
duplicates = set[str]()
|
||||
for i, processed_arg in enumerate(processed_args):
|
||||
if processed_arg.startswith("--") and "." in processed_arg:
|
||||
if i in delete: # skip if value from previous arg
|
||||
continue
|
||||
|
||||
if processed_arg.startswith("-") and "." in processed_arg:
|
||||
if "=" in processed_arg:
|
||||
processed_arg, value_str = processed_arg.split("=", 1)
|
||||
if "." not in processed_arg:
|
||||
# False positive, . was only in the value
|
||||
# False positive, '.' was only in the value
|
||||
continue
|
||||
else:
|
||||
value_str = processed_args[i + 1]
|
||||
delete.add(i + 1)
|
||||
|
||||
if processed_arg.endswith("+"):
|
||||
processed_arg = processed_arg[:-1]
|
||||
value_str = json.dumps(list(value_str.split(",")))
|
||||
|
||||
key, *keys = processed_arg.split(".")
|
||||
try:
|
||||
value = json.loads(value_str)
|
||||
@ -1533,12 +1555,17 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
|
||||
# Merge all values with the same key into a single dict
|
||||
arg_dict = create_nested_dict(keys, value)
|
||||
recursive_dict_update(dict_args[key], arg_dict)
|
||||
arg_duplicates = recursive_dict_update(dict_args[key],
|
||||
arg_dict)
|
||||
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
|
||||
delete.add(i)
|
||||
# Filter out the dict args we set to None
|
||||
processed_args = [
|
||||
a for i, a in enumerate(processed_args) if i not in delete
|
||||
]
|
||||
if duplicates:
|
||||
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
|
||||
|
||||
# Add the dict args back as if they were originally passed as JSON
|
||||
for dict_arg, dict_value in dict_args.items():
|
||||
processed_args.append(dict_arg)
|
||||
@ -2405,7 +2432,7 @@ def memory_profiling(
|
||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
|
||||
|
||||
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
|
||||
""" # noqa
|
||||
""" # noqa
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
Reference in New Issue
Block a user