[quantization] use channel scales for w4a8 + misc fixes (#23570)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2025-08-26 21:23:23 -04:00
committed by GitHub
parent c7c80af084
commit 2c2b140ae8
4 changed files with 63 additions and 14 deletions

View File

@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensors24, CompressedTensorsLinearMethod,
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8,
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@ -683,3 +683,39 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(
not current_platform.is_cuda()
or not current_platform.has_device_capability(90),
reason="W4A8 FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize("args", [
("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)
])
def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
model, scheme = args
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
o_proj = layer.self_attn.o_proj
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj
for proj in (qkv_proj, o_proj, gate_up_proj, down_proj):
assert isinstance(proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(proj.scheme, scheme)
assert proj.weight_packed.dtype is torch.int32
assert proj.weight_scale.dtype is torch.float8_e4m3fn
assert proj.weight_chan_scale.dtype is torch.float32
assert proj.scheme.group_size == 128
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output