[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
This commit is contained in:
Cody Yu
2024-05-22 13:28:20 -07:00
committed by GitHub
parent 8674f9880e
commit a3a73ab069
40 changed files with 284 additions and 158 deletions

View File

@ -153,15 +153,13 @@ if __name__ == '__main__':
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,