[V1] Support LLM.apply_model (#18465)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
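The tests below migrate from reaching into V0 engine internals
(`llm_engine.model_executor.driver_worker.model_runner.model`) to the
`LLM.apply_model(func)` API, which runs `func` against each worker's loaded
`torch.nn.Module` and, per this commit, now works on V1 as well. As a minimal
usage sketch (an illustration, not part of the diff; it assumes a single-GPU
run, and takes the model name and the environment variable from the changes
below — `apply_model` has to pickle the function, so
`VLLM_ALLOW_INSECURE_SERIALIZATION` must be set before the engine starts):

    import os

    # `LLM.apply_model` requires pickling a function.
    os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m")

    def check_model(model):
        # Runs in each worker process against the loaded torch.nn.Module.
        assert model is not None

    llm.apply_model(check_model)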
@@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


 @pytest.mark.parametrize(
@@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(

     dtype = "bfloat16"

-    # skip language translation prompt for the static per tensor asym model
-    if (model_path ==
-            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
-        ):  # noqa: E501
+    # skip language translation prompt for the static per tensor models
+    if model_path in (
+            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
+            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
+    ):
         example_prompts = example_prompts[0:-1]

     with hf_runner(model_path, dtype=dtype) as hf_model:
@@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
     if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

         def check_model(model):
@@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
@@ -31,41 +31,46 @@ MODEL_QUANT = [
 @pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
 def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
                            monkeypatch):
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
-
-    vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

     linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
         GPTQLinearMethod)

-    for name, submodule in (vllm_model.llm.llm_engine.model_executor.
-                            driver_worker.model_runner.model.named_modules()):
-        if name == "lm_head":
-            assert isinstance(submodule.quant_method, linear_method_cls)
-        elif name == 'model.layers.0.self_attn.qkv_proj':
-            # The first layer is quantized using bits=4, group_size=128
-            # desc_act=True
-            assert isinstance(submodule.quant_method, linear_method_cls)
-            config = submodule.quant_method.quant_config
-            assert config.weight_bits == 4
-            assert config.group_size == 128
-            assert config.desc_act
-        elif name == 'model.layers.1.self_attn.qkv_proj':
-            # The second layer is quantized using bits=8, group_size=32
-            # desc_act=False
-            assert isinstance(submodule.quant_method, linear_method_cls)
-            config = submodule.quant_method.quant_config
-            assert get_dynamic_override(config, layer_name=name,
-                                        key="bits") == 8
-            assert get_dynamic_override(config,
-                                        layer_name=name,
-                                        key="group_size") == 32
-            assert not get_dynamic_override(
-                config, layer_name=name, key="desc_act")
-        elif (name == 'model.layers.2.self_attn.qkv_proj'
-              or name == 'model.layers.2.mlp.gate_up_proj'):
-            # All other layers (layer index >= 2) are not quantized
-            assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
-
-    del vllm_model
+    with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm:
+
+        def check_model(model):
+            for name, submodule in model.named_modules():
+                if name == "lm_head":
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                elif name == 'model.layers.0.self_attn.qkv_proj':
+                    # The first layer is quantized using bits=4, group_size=128
+                    # desc_act=True
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                    config = submodule.quant_method.quant_config
+                    assert config.weight_bits == 4
+                    assert config.group_size == 128
+                    assert config.desc_act
+                elif name == 'model.layers.1.self_attn.qkv_proj':
+                    # The second layer is quantized using bits=8, group_size=32
+                    # desc_act=False
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                    config = submodule.quant_method.quant_config
+                    assert get_dynamic_override(config,
+                                                layer_name=name,
+                                                key="bits") == 8
+                    assert get_dynamic_override(config,
+                                                layer_name=name,
+                                                key="group_size") == 32
+                    assert not get_dynamic_override(
+                        config, layer_name=name, key="desc_act")
+                elif (name == 'model.layers.2.self_attn.qkv_proj'
+                      or name == 'model.layers.2.mlp.gate_up_proj'):
+                    # All other layers (layer index >= 2) are not quantized
+                    assert isinstance(submodule.quant_method,
+                                      UnquantizedLinearMethod)
+
+        llm.apply_model(check_model)
@@ -29,8 +29,8 @@ def test_lm_head(
     lm_head_quantized: bool,
     monkeypatch,
 ) -> None:
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_id, dtype=torch.float16,
                      max_model_len=2048) as vllm_model:
@@ -11,16 +11,12 @@ import pytest
 import torch

 from tests.quantization.utils import is_quant_method_supported
-from vllm.platforms import current_platform


 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


 @pytest.mark.skipif(not is_quant_method_supported("modelopt"),
@@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
     PTPCFp8LinearMethod)
 from vllm.platforms import current_platform

+UNSUPPORTED_STR = (
+    "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
+    "support output dtype of bfloat16. torch.float16 is specified.")
+
+
+@pytest.fixture(scope="function", autouse=True)
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+

 @pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
                     reason="PTPC FP8 is not supported on this GPU type.")
@@ -21,14 +31,22 @@ from vllm.platforms import current_platform
 @pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
 def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:

     try:
-        with vllm_runner("facebook/opt-125m",
-                         dtype=dtype,
-                         quantization="ptpc_fp8",
-                         kv_cache_dtype=kv_cache_dtype) as llm:
+        llm = vllm_runner("facebook/opt-125m",
+                          dtype=dtype,
+                          quantization="ptpc_fp8",
+                          kv_cache_dtype=kv_cache_dtype)
+    except AssertionError as e:
+        if str(e) == UNSUPPORTED_STR:
+            # If the error message matches, the test passes
+            return
+        else:
+            # If the error message does not match, re-raise the exception
+            raise

-            model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+    with llm:
+
+        def check_model(model):
             fc1 = model.model.decoder.layers[0].fc1
             assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
             if kv_cache_dtype == "ptpc_fp8":
@@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
             if current_platform.has_device_capability(94):
                 # For GPUs with hardware support, we keep weights in fp8
                 assert fc1.weight.dtype == torch.float8_e4m3fnuz
             else:
                 pytest.skip()

-            output = llm.generate_greedy("Hello my name is", max_tokens=20)
-            assert output
-    except AssertionError as e:
-        if str(
-                e
-        ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.":  # noqa: E501
-            # If the error message matches, the test passes
-            pass
-        else:
-            # If the error message does not match, re-raise the exception
-            raise
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
@@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
 See also `tests/kernels/moe/test_mxfp4_moe.py`.
 """

-import importlib
 import importlib.metadata
 import os
 from dataclasses import dataclass
+from importlib.util import find_spec

 import huggingface_hub
 import lm_eval
@@ -24,9 +24,8 @@ from vllm.platforms import current_platform

 from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

-QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
-    "quark") is not None and version.parse(
-        importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
+QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
+    importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')

 if QUARK_MXFP4_AVAILABLE:
     from quark.torch.export.nn.modules.realquantizer import (
@@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:


 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


 @pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
@@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
     }
     with (vllm_runner(quark_model_id, **llm_kwargs) as
           quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
-        quark_model = (quark_handle.llm.llm_engine.model_executor.
-                       driver_worker.model_runner.model)
-        quark_state_dict = quark_model.state_dict()

-        fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker.
-                     model_runner.model)
-        fp8_state_dict = fp8_model.state_dict()
+        def get_state_dict(model):
+            return {k: v.cpu() for k, v in model.state_dict().items()}
+
+        quark_state_dict, = quark_handle.apply_model(get_state_dict)
+        fp8_state_dict, = fp8_handle.apply_model(get_state_dict)

     assert fp8_state_dict.keys() == quark_state_dict.keys()
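A note on the hunk above (an inference from the code, not stated in the
commit message): the single-element unpacking `quark_state_dict, = ...`
suggests `apply_model` returns one result per worker, so a single-GPU run
yields a one-element list, while tensor parallelism would yield one entry
per rank. The `.cpu()` copy in `get_state_dict` is what lets the returned
tensors outlive the worker. A sketch of the same pattern, assuming a
single-GPU `llm` handle:

    def get_state_dict(model):
        # Copy to CPU so the tensors remain valid outside the worker.
        return {k: v.cpu() for k, v in model.state_dict().items()}

    state_dict, = llm.apply_model(get_state_dict)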
@@ -105,18 +105,21 @@ def test_register_quantization_config():
 ])
 def test_custom_quant(vllm_runner, model, monkeypatch):
     """Test infer with the custom quantization method."""
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

     with vllm_runner(model_name=model,
                      quantization="custom_quant",
                      enforce_eager=True) as llm:
-        model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-        qkv_proj = layer.self_attn.qkv_proj

-        # Check the quantization method is FakeQuantLinearMethod
-        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
+        def check_model(model):
+            layer = model.model.layers[0]
+            qkv_proj = layer.self_attn.qkv_proj
+
+            # Check the quantization method is FakeQuantLinearMethod
+            assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output