Compare commits
2 Commits
v0.11.0rc2
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
| 1c2cf6926d | |||
| 3764fe0db3 |
@ -247,6 +247,60 @@ def test_compilation_config():
|
||||
and args.compilation_config.use_inductor)
|
||||
|
||||
|
||||
def test_compilation_config_json_and_dot_notation():
|
||||
"""Test that JSON and dot notation arguments can be combined correctly."""
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
|
||||
# Test case 1: JSON then dot notation
|
||||
args = parser.parse_args([
|
||||
'-O', '{"cudagraph_mode": "FULL_DECODE_ONLY"}',
|
||||
'-O.debug_dump_path=/home/alexm/debug_dump'
|
||||
])
|
||||
config = args.compilation_config
|
||||
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
|
||||
assert config.debug_dump_path == "/home/alexm/debug_dump"
|
||||
|
||||
# Test case 2: Dot notation then JSON
|
||||
args = parser.parse_args([
|
||||
'-O.debug_dump_path=/home/alexm/debug_dump',
|
||||
'-O', '{"cudagraph_mode": "FULL_DECODE_ONLY"}'
|
||||
])
|
||||
config = args.compilation_config
|
||||
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
|
||||
assert config.debug_dump_path == "/home/alexm/debug_dump"
|
||||
|
||||
# Test case 3: Multiple dot notation arguments
|
||||
args = parser.parse_args([
|
||||
'-O.cudagraph_mode=FULL_DECODE_ONLY',
|
||||
'-O.debug_dump_path=/home/alexm/debug_dump'
|
||||
])
|
||||
config = args.compilation_config
|
||||
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
|
||||
assert config.debug_dump_path == "/home/alexm/debug_dump"
|
||||
|
||||
# Test case 4: Multiple JSON arguments
|
||||
args = parser.parse_args([
|
||||
'-O', '{"cudagraph_mode": "FULL_DECODE_ONLY"}',
|
||||
'-O', '{"debug_dump_path": "/home/alexm/debug_dump"}'
|
||||
])
|
||||
config = args.compilation_config
|
||||
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
|
||||
assert config.debug_dump_path == "/home/alexm/debug_dump"
|
||||
|
||||
# Test case 5: Mix all formats
|
||||
args = parser.parse_args([
|
||||
'-O', '{"level": 1}',
|
||||
'-O.cudagraph_mode=FULL_DECODE_ONLY',
|
||||
'-O', '{"debug_dump_path": "/home/alexm/debug_dump"}',
|
||||
'-O.use_inductor=true'
|
||||
])
|
||||
config = args.compilation_config
|
||||
assert config.level == 1
|
||||
assert config.cudagraph_mode.name == "FULL_DECODE_ONLY"
|
||||
assert config.debug_dump_path == "/home/alexm/debug_dump"
|
||||
assert config.use_inductor is True
|
||||
|
||||
|
||||
def test_prefix_cache_default():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
args = parser.parse_args([])
|
||||
|
||||
@ -1910,6 +1910,24 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
arg_dict)
|
||||
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
|
||||
delete.add(i)
|
||||
elif (processed_arg.startswith("-")
|
||||
and i + 1 < len(processed_args)
|
||||
and not processed_args[i + 1].startswith("-")):
|
||||
# Handle standalone JSON arguments like -O '{"key": "value"}'
|
||||
value_str = processed_args[i + 1]
|
||||
try:
|
||||
parsed_json = json.loads(value_str)
|
||||
if isinstance(parsed_json, dict):
|
||||
# This is a JSON argument, merge it with existing dict_args
|
||||
key = processed_arg
|
||||
arg_duplicates = recursive_dict_update(dict_args[key],
|
||||
parsed_json)
|
||||
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
|
||||
delete.add(i)
|
||||
delete.add(i + 1)
|
||||
except json.decoder.JSONDecodeError:
|
||||
# Not a JSON argument, let it pass through normally
|
||||
pass
|
||||
# Filter out the dict args we set to None
|
||||
processed_args = [
|
||||
a for i, a in enumerate(processed_args) if i not in delete
|
||||
|
||||
Reference in New Issue
Block a user