[Misc] Refactor linear layer weight loading; introduce BasevLLMParameter and weight_loader_v2 (#5874)

commit 0f7052bc7e
parent 639159b2a6
Author: Dipika Sikka
Date:   2024-08-07 12:17:58 -04:00
Committed by: GitHub
11 changed files with 653 additions and 201 deletions
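The hunks below are the test-side view of the change. As orientation, here is a minimal sketch of the pattern named in the title; the class and function signatures are illustrative assumptions, not the code this commit adds. The idea is that a parameter object carries its own sharding metadata and loading logic, and the layer-level weight_loader_v2 simply delegates to it instead of special-casing each quantization scheme:

    # Hedged sketch of the BasevLLMParameter / weight_loader_v2 idea
    # (illustrative names and signatures, not the actual vLLM code).
    import torch
    from torch.nn import Parameter


    class BasevLLMParameter(Parameter):
        """Parameter that knows which dimension it is sharded along."""

        def __new__(cls, data: torch.Tensor, output_dim: int = 0):
            return super().__new__(cls, data, requires_grad=False)

        def __init__(self, data: torch.Tensor, output_dim: int = 0):
            self.output_dim = output_dim

        def load_merged_column_weight(self, loaded_weight: torch.Tensor,
                                      shard_offset: int, shard_size: int):
            # Copy one logical shard (e.g. the q/k/v slice of a fused
            # qkv_proj) into the right slice of this parameter's data.
            self.data.narrow(self.output_dim, shard_offset,
                             shard_size).copy_(loaded_weight)


    def weight_loader_v2(param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         shard_offset: int, shard_size: int) -> None:
        # The layer only dispatches; the parameter places its own data.
        param.load_merged_column_weight(loaded_weight, shard_offset, shard_size)


    # Usage: load two 4-row shards into a fused 8x16 weight.
    fused = BasevLLMParameter(torch.empty(8, 16), output_dim=0)
    weight_loader_v2(fused, torch.randn(4, 16), shard_offset=0, shard_size=4)
    weight_loader_v2(fused, torch.randn(4, 16), shard_offset=4, shard_size=4)

The payoff of such a refactor is that per-scheme loading details live with the parameter and scheme objects rather than inside each linear layer's weight loader.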


@@ -9,7 +9,7 @@ import torch
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsWNA16)
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationType)
@@ -109,7 +109,7 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
         assert qkv_proj.weight_packed.dtype is torch.int32
         assert qkv_proj.weight_scale.dtype is torch.float16
-        assert qkv_proj.weight_packed.pack_factor == pack_factor
+        assert qkv_proj.scheme.pack_factor == pack_factor
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
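The only change in this hunk is where the test reads the pack factor from: the scheme object rather than the packed weight parameter. For readers unfamiliar with the term, a tiny self-contained illustration of what a pack factor means (illustrative numbers, not vLLM code):

    import torch

    bits = 4
    pack_factor = 32 // bits          # 8 four-bit values fit in one int32
    rows, cols = 64, 16

    # One dimension of the stored tensor shrinks by pack_factor (which axis
    # is packed is a layout detail of the scheme; this example packs rows).
    weight_packed = torch.zeros(rows // pack_factor, cols, dtype=torch.int32)
    assert weight_packed.numel() * pack_factor == rows * cols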
@@ -140,13 +140,17 @@ def test_compressed_tensors_fp8(vllm_runner):
         qkv_proj = layer.self_attn.qkv_proj
 
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
-        assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+        assert isinstance(
+            qkv_proj.scheme,
+            (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
         assert qkv_proj.input_scale.dtype is torch.float32
         assert qkv_proj.weight_scale.dtype is torch.float32
         # should be scalars after processing
-        assert len(qkv_proj.input_scale.shape) == 0
-        assert len(qkv_proj.weight_scale.shape) == 0
+        if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
+            assert len(qkv_proj.input_scale.shape) == 0
+        assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+        assert qkv_proj.weight_scale.dtype is torch.float32
+        assert len(qkv_proj.weight_scale.shape) == 0
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
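This hunk widens the fp8 test to accept either a scheme that also quantizes activations to FP8 (W8A8) or one that keeps 16-bit activations (W8A16), and only expects a scalar input_scale in the former. The "should be scalars after processing" comment refers to per-tensor scales being collapsed to a 0-dim tensor once the fused shards are loaded; a hedged illustration of that kind of post-processing, not vLLM's implementation:

    import torch

    # Scales loaded for the q, k and v shards of a fused projection
    # (made-up values).
    loaded_scales = torch.tensor([0.017, 0.021, 0.019], dtype=torch.float32)

    # One common way to get a single per-tensor scale is to keep the max
    # (and requantize the weights against it); the result is 0-dim.
    input_scale = loaded_scales.max()

    assert input_scale.dtype is torch.float32
    assert len(input_scale.shape) == 0   # matches the test's scalar check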