diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index cfbc7c245f..847f150bd6 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -239,32 +239,40 @@ def test_compilation_config(): assert args.compilation_config == CompilationConfig() # set to O3 - args = parser.parse_args(["-O3"]) - assert args.compilation_config.level == 3 + args = parser.parse_args(["-O0"]) + assert args.compilation_config.level == 0 # set to O 3 (space) - args = parser.parse_args(["-O", "3"]) - assert args.compilation_config.level == 3 + args = parser.parse_args(["-O", "1"]) + assert args.compilation_config.level == 1 # set to O 3 (equals) - args = parser.parse_args(["-O=3"]) + args = parser.parse_args(["-O=2"]) + assert args.compilation_config.level == 2 + + # set to O.level 3 + args = parser.parse_args(["-O.level", "3"]) assert args.compilation_config.level == 3 # set to string form of a dict args = parser.parse_args([ - "--compilation-config", - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', + "-O", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": false}', ]) assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) + args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and not args.compilation_config.use_inductor) # set to string form of a dict args = parser.parse_args([ "--compilation-config=" - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": true}', ]) assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) + args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and args.compilation_config.use_inductor) def test_prefix_cache_default(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 913188455d..36db8202ba 100644 --- a/tests/test_utils.py +++ 
b/tests/test_utils.py @@ -5,6 +5,7 @@ import asyncio import hashlib import json +import logging import pickle import socket from collections.abc import AsyncIterator @@ -142,6 +143,7 @@ def parser(): parser.add_argument('--batch-size', type=int) parser.add_argument('--enable-feature', action='store_true') parser.add_argument('--hf-overrides', type=json.loads) + parser.add_argument('-O', '--compilation-config', type=json.loads) return parser @@ -265,6 +267,11 @@ def test_dict_args(parser): "val2", "--hf-overrides.key2.key4", "val3", + # Test compile config and compilation level + "-O.use_inductor=true", + "-O.backend", + "custom", + "-O1", # Test = sign "--hf-overrides.key5=val4", # Test underscore to dash conversion @@ -281,6 +288,13 @@ def test_dict_args(parser): "true", "--hf_overrides.key12.key13", "null", + # Test '-' and '.' in value + "--hf_overrides.key14.key15", + "-minus.and.dot", + # Test array values + "-O.custom_ops+", + "-quant_fp8", + "-O.custom_ops+=+silu_mul,-rms_norm", ] parsed_args = parser.parse_args(args) assert parsed_args.model_name == "something.something" @@ -301,7 +315,40 @@ def test_dict_args(parser): "key12": { "key13": None, }, + "key14": { + "key15": "-minus.and.dot", + } } + assert parsed_args.compilation_config == { + "level": 1, + "use_inductor": True, + "backend": "custom", + "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"], + } + + +def test_duplicate_dict_args(caplog_vllm, parser): + args = [ + "--model-name=something.something", + "--hf-overrides.key1", + "val1", + "--hf-overrides.key1", + "val2", + "-O1", + "-O.level", + "2", + "-O3", + ] + + parsed_args = parser.parse_args(args) + # Should be the last value + assert parsed_args.hf_overrides == {"key1": "val2"} + assert parsed_args.compilation_config == {"level": 3} + + assert len(caplog_vllm.records) == 1 + assert "duplicate" in caplog_vllm.text + assert "--hf-overrides.key1" in caplog_vllm.text + assert "-O.level" in caplog_vllm.text # yapf: enable diff --git 
a/vllm/config.py b/vllm/config.py index 57b9df2364..46a5bf34f6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4140,9 +4140,9 @@ class CompilationConfig: @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": - """Parse the CLI value for the compilation config.""" - if cli_value in ["0", "1", "2", "3"]: - return cls(level=int(cli_value)) + """Parse the CLI value for the compilation config. + -O1, -O2, -O3, etc. are handled in FlexibleArgumentParser. + """ return TypeAdapter(CompilationConfig).validate_json(cli_value) def __post_init__(self) -> None: @@ -4303,17 +4303,16 @@ class VllmConfig: """Quantization configuration.""" compilation_config: CompilationConfig = field( default_factory=CompilationConfig) - """`torch.compile` configuration for the model. + """`torch.compile` and cudagraph capture configuration for the model. - When it is a number (0, 1, 2, 3), it will be interpreted as the - optimization level. + As a shorthand, `-O<n>` can be used to directly specify the compilation + level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). + Currently, -O <n> and -O=<n> are supported as well but this will likely be + removed in favor of clearer -O<n> syntax in the future. NOTE: level 0 is the default level without any optimization. level 1 and 2 are for internal testing only. level 3 is the recommended level for - production. - - Following the convention of traditional compilers, using `-O` without space - is also supported. `-O3` is equivalent to `-O 3`. + production, also default in V1. You can specify the full compilation config like so: `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6c908f88b9..2d3783363c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: passed individually. 
For example, the following sets of arguments are equivalent:\n\n - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n - - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n + Additionally, list elements can be passed individually using '+': + - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n + - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n""" if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: diff --git a/vllm/utils.py b/vllm/utils.py index 689102281c..60e560c70a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -89,15 +89,15 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." + "is not currently supported." STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ "Prefix caching for encoder/decoder models " + \ - "is not currently supported." + "is not currently supported." STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Chunked prefill for encoder/decoder models " + \ - "is not currently supported." + "is not currently supported." STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( "Models with logits_soft_cap " @@ -752,7 +752,7 @@ def _generate_random_fp8( # to generate random data for fp8 data. # For example, s.11111.00 in fp8e5m2 format represents Inf. 
# | E4M3 | E5M2 - #-----|-------------|------------------- + # -----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} from vllm import _custom_ops as ops @@ -840,7 +840,6 @@ def create_kv_caches_with_random( seed: Optional[int] = None, device: Optional[str] = "cuda", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - if cache_dtype == "fp8" and head_size % 16: raise ValueError( f"Does not support key cache of type fp8 with head_size {head_size}" @@ -1205,7 +1204,6 @@ def deprecate_args( is_deprecated: Union[bool, Callable[[], bool]] = True, additional_message: Optional[str] = None, ) -> Callable[[F], F]: - if not callable(is_deprecated): is_deprecated = partial(identity, is_deprecated) @@ -1355,7 +1353,7 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: return weak_bound -#From: https://stackoverflow.com/a/4104188/2749989 +# From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: @@ -1474,7 +1472,7 @@ class FlexibleArgumentParser(ArgumentParser): # Convert underscores to dashes and vice versa in argument names processed_args = list[str]() - for arg in args: + for i, arg in enumerate(args): if arg.startswith('--'): if '=' in arg: key, value = arg.split('=', 1) @@ -1483,10 +1481,17 @@ class FlexibleArgumentParser(ArgumentParser): else: key = pattern.sub(repl, arg, count=1) processed_args.append(key) - elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: - # allow -O flag to be used without space, e.g. -O3 - processed_args.append('-O') - processed_args.append(arg[2:]) + elif arg.startswith('-O') and arg != '-O' and arg[2] != '.': + # allow -O flag to be used without space, e.g. 
-O3 or -Odecode + # -O.<...> handled later + # also handle -O= here + level = arg[3:] if arg[2] == '=' else arg[2:] + processed_args.append(f'-O.level={level}') + elif arg == '-O' and i + 1 < len(args) and args[i + 1] in { + "0", "1", "2", "3" + }: + # Convert -O to -O.level + processed_args.append('-O.level') else: processed_args.append(arg) @@ -1504,27 +1509,44 @@ class FlexibleArgumentParser(ArgumentParser): def recursive_dict_update( original: dict[str, Any], update: dict[str, Any], - ): - """Recursively updates a dictionary with another dictionary.""" + ) -> set[str]: + """Recursively updates a dictionary with another dictionary. + Returns a set of duplicate keys that were overwritten. + """ + duplicates = set[str]() for k, v in update.items(): if isinstance(v, dict) and isinstance(original.get(k), dict): - recursive_dict_update(original[k], v) + nested_duplicates = recursive_dict_update(original[k], v) + duplicates |= {f"{k}.{d}" for d in nested_duplicates} + elif isinstance(v, list) and isinstance(original.get(k), list): + original[k] += v else: + if k in original: + duplicates.add(k) original[k] = v + return duplicates delete = set[int]() dict_args = defaultdict[str, dict[str, Any]](dict) + duplicates = set[str]() for i, processed_arg in enumerate(processed_args): - if processed_arg.startswith("--") and "." in processed_arg: + if i in delete: # skip if value from previous arg + continue + + if processed_arg.startswith("-") and "." in processed_arg: if "=" in processed_arg: processed_arg, value_str = processed_arg.split("=", 1) if "." not in processed_arg: - # False positive, . was only in the value + # False positive, '.' 
was only in the value continue else: value_str = processed_args[i + 1] delete.add(i + 1) + if processed_arg.endswith("+"): + processed_arg = processed_arg[:-1] + value_str = json.dumps(list(value_str.split(","))) + key, *keys = processed_arg.split(".") try: value = json.loads(value_str) @@ -1533,12 +1555,17 @@ class FlexibleArgumentParser(ArgumentParser): # Merge all values with the same key into a single dict arg_dict = create_nested_dict(keys, value) - recursive_dict_update(dict_args[key], arg_dict) + arg_duplicates = recursive_dict_update(dict_args[key], + arg_dict) + duplicates |= {f'{key}.{d}' for d in arg_duplicates} delete.add(i) # Filter out the dict args we set to None processed_args = [ a for i, a in enumerate(processed_args) if i not in delete ] + if duplicates: + logger.warning("Found duplicate keys %s", ", ".join(duplicates)) + # Add the dict args back as if they were originally passed as JSON for dict_arg, dict_value in dict_args.items(): processed_args.append(dict_arg) @@ -2405,7 +2432,7 @@ def memory_profiling( The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). - """ # noqa + """ # noqa gc.collect() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats()