[Bugfix] Fix quark fp8 format loading on AMD GPUs (#12612)
Signed-off-by: Felix Marty <felmarty@amd.com> Signed-off-by: kewang2 <kewang2@amd.com> Co-authored-by: kewang2 <kewang2@amd.com>
This commit is contained in:
@ -5,6 +5,7 @@ Run `pytest tests/quantization/test_quark.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
|
||||
QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
|
||||
@ -63,3 +64,28 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
assert output
|
||||
|
||||
|
||||
def test_quark_fp8_parity(vllm_runner):
|
||||
quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
|
||||
fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
|
||||
|
||||
llm_kwargs = {
|
||||
"tensor_parallel_size": 1,
|
||||
"enforce_eager": True,
|
||||
"gpu_memory_utilization": 0.1
|
||||
}
|
||||
with (vllm_runner(quark_model_id, **llm_kwargs) as
|
||||
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
|
||||
quark_model = (quark_handle.model.llm_engine.model_executor.
|
||||
driver_worker.model_runner.model)
|
||||
quark_state_dict = quark_model.state_dict()
|
||||
|
||||
fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
fp8_state_dict = fp8_model.state_dict()
|
||||
|
||||
assert fp8_state_dict.keys() == quark_state_dict.keys()
|
||||
|
||||
for key in fp8_state_dict:
|
||||
assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
|
||||
|
||||
Reference in New Issue
Block a user