Compare commits

...

2 Commits

2 changed files with 72 additions and 0 deletions

View File

@ -247,6 +247,60 @@ def test_compilation_config():
and args.compilation_config.use_inductor)
def test_compilation_config_json_and_dot_notation():
"""Test that JSON and dot notation arguments can be combined correctly."""
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
# Test case 1: JSON then dot notation
args = parser.parse_args([
'-O', '{"cudagraph_mode": "FULL_DECODE_ONLY"}',
'-O.debug_dump_path=/home/alexm/debug_dump'
])
config = args.compilation_config
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
assert config.debug_dump_path == "/home/alexm/debug_dump"
# Test case 2: Dot notation then JSON
args = parser.parse_args([
'-O.debug_dump_path=/home/alexm/debug_dump',
'-O', '{"cudagraph_mode": "FULL_DECODE_ONLY"}'
])
config = args.compilation_config
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
assert config.debug_dump_path == "/home/alexm/debug_dump"
# Test case 3: Multiple dot notation arguments
args = parser.parse_args([
'-O.cudagraph_mode=FULL_DECODE_ONLY',
'-O.debug_dump_path=/home/alexm/debug_dump'
])
config = args.compilation_config
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
assert config.debug_dump_path == "/home/alexm/debug_dump"
# Test case 4: Multiple JSON arguments
args = parser.parse_args([
'-O', '{"cudagraph_mode": "FULL_DECODE_ONLY"}',
'-O', '{"debug_dump_path": "/home/alexm/debug_dump"}'
])
config = args.compilation_config
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
assert config.debug_dump_path == "/home/alexm/debug_dump"
# Test case 5: Mix all formats
args = parser.parse_args([
'-O', '{"level": 1}',
'-O.cudagraph_mode=FULL_DECODE_ONLY',
'-O', '{"debug_dump_path": "/home/alexm/debug_dump"}',
'-O.use_inductor=true'
])
config = args.compilation_config
assert config.level == 1
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
assert config.debug_dump_path == "/home/alexm/debug_dump"
assert config.use_inductor is True
def test_prefix_cache_default():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([])

View File

@ -1910,6 +1910,24 @@ class FlexibleArgumentParser(ArgumentParser):
arg_dict)
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
delete.add(i)
elif (processed_arg.startswith("-")
and i + 1 < len(processed_args)
and not processed_args[i + 1].startswith("-")):
# Handle standalone JSON arguments like -O '{"key": "value"}'
value_str = processed_args[i + 1]
try:
parsed_json = json.loads(value_str)
if isinstance(parsed_json, dict):
# This is a JSON argument, merge it with existing dict_args
key = processed_arg
arg_duplicates = recursive_dict_update(dict_args[key],
parsed_json)
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
delete.add(i)
delete.add(i + 1)
except json.decoder.JSONDecodeError:
# Not a JSON argument, let it pass through normally
pass
# Filter out the dict args we set to None
processed_args = [
a for i, a in enumerate(processed_args) if i not in delete