[quantization] use channel scales for w4a8 + misc fixes (#23570)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
@@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensors24, CompressedTensorsLinearMethod,
-    CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8,
+    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -683,3 +683,39 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         print(output)
         assert output
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
+    reason="W4A8 FP8 is not yet supported on this GPU type.",
+)
+@pytest.mark.parametrize("args", [
+    ("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)
+])
+def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
+    model, scheme = args
+    with vllm_runner(model, enforce_eager=True) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            o_proj = layer.self_attn.o_proj
+            gate_up_proj = layer.mlp.gate_up_proj
+            down_proj = layer.mlp.down_proj
+
+            for proj in (qkv_proj, o_proj, gate_up_proj, down_proj):
+                assert isinstance(proj.quant_method,
+                                  CompressedTensorsLinearMethod)
+                assert isinstance(proj.scheme, scheme)
+
+                assert proj.weight_packed.dtype is torch.int32
+                assert proj.weight_scale.dtype is torch.float8_e4m3fn
+                assert proj.weight_chan_scale.dtype is torch.float32
+                assert proj.scheme.group_size == 128
+
+        llm.apply_model(check_model)
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        print(output)
+        assert output
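For readers skimming the diff, the new assertions pin down the two-level scale layout this PR wires up for W4A8: int4 weights packed into int32 words (`weight_packed`), fp8 (e4m3) scales per 128-wide group (`weight_scale`), and a fp32 scale per output channel (`weight_chan_scale`). Below is a minimal reference dequantization sketch in plain PyTorch. The nibble packing order, the offset-8 zero point, and the tensor shapes are illustrative assumptions only; the real unpacking happens inside vLLM's CUTLASS W4A8 kernels, not in Python.

import torch

def dequant_w4a8_reference(
        weight_packed: torch.Tensor,      # [out, in // 8], int32 (assumed shape)
        weight_scale: torch.Tensor,       # [out, in // group_size], float8_e4m3fn
        weight_chan_scale: torch.Tensor,  # [out, 1], float32
        group_size: int = 128) -> torch.Tensor:
    """Hypothetical fp32 reconstruction of a W4A8 weight matrix."""
    # Unpack eight 4-bit nibbles per int32 word (low-to-high order assumed).
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    nibbles = (weight_packed.unsqueeze(-1) >> shifts) & 0xF
    # Assume an offset-8 encoding for the signed int4 values.
    q = nibbles.reshape(weight_packed.shape[0], -1).to(torch.float32) - 8.0

    # Level 1: fp8 group scales, one per group_size input channels.
    group_scales = weight_scale.to(torch.float32)
    w = q * group_scales.repeat_interleave(group_size, dim=1)

    # Level 2: the per-output-channel fp32 scale this PR adds.
    return w * weight_chan_scale.to(torch.float32)

The likely motivation (hedged, since the PR description is only the title): fp8 group scales alone have limited dynamic range, so factoring a per-output-channel fp32 magnitude out of them keeps outlier channels from clipping while the fine-grained scales stay cheap to store in fp8.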