[Bugfix] Fix quark fp8 format loading on AMD GPUs (#12612)

Signed-off-by: Felix Marty <felmarty@amd.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>
This commit is contained in:
fxmarty-amd
2025-05-08 11:53:53 +02:00
committed by GitHub
parent a463555dee
commit bb239a730f
2 changed files with 38 additions and 9 deletions

View File

@ -5,6 +5,7 @@ Run `pytest tests/quantization/test_quark.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
@ -63,3 +64,28 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
def test_quark_fp8_parity(vllm_runner):
quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
llm_kwargs = {
"tensor_parallel_size": 1,
"enforce_eager": True,
"gpu_memory_utilization": 0.1
}
with (vllm_runner(quark_model_id, **llm_kwargs) as
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
quark_model = (quark_handle.model.llm_engine.model_executor.
driver_worker.model_runner.model)
quark_state_dict = quark_model.state_dict()
fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
model_runner.model)
fp8_state_dict = fp8_model.state_dict()
assert fp8_state_dict.keys() == quark_state_dict.keys()
for key in fp8_state_dict:
assert torch.equal(fp8_state_dict[key], quark_state_dict[key])