Compare commits
4 Commits
zhuohan/re...mla-suppor

| Author | SHA1 | Date |
|---|---|---|
| | 243408b6b4 | |
| | b8510f1081 | |
| | 09318caeba | |
| | d56ef8b685 | |
@@ -215,7 +215,7 @@ def rms_norm_dynamic_per_token_quant(
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                    zeros: torch.Tensor, split_k_iters: int, thx: int,
                    thy: int) -> torch.Tensor:
-    if envs.VLLM_USE_TRITON_AWQ:
+    if envs.VLLM_USE_TRITON_AWQ or qweight.dtype != torch.float16:
         from vllm.model_executor.layers.quantization.awq_triton import (
             awq_dequantize_triton)
         return awq_dequantize_triton(qweight, scales, zeros)

@@ -18,6 +18,8 @@ from vllm.distributed import (get_tensor_model_parallel_world_size,
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.awq_marlin import (
+    AWQMarlinLinearMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (

@@ -227,8 +229,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))

         def quantization_scheme_supported(layer: LinearBase) -> bool:
-            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
-                is_layer_fp8(layer)
+            return isinstance(layer.quant_method,
+                              (UnquantizedLinearMethod,
+                               AWQMarlinLinearMethod)) or is_layer_fp8(layer)

         # TODO(lucas) This is very gross, we need a more wide scale refactor of
         # all the FP8 code with a more standard way of

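Note: the widened predicate relies on `isinstance` accepting a tuple of classes, so a single check covers both unquantized and AWQ-Marlin layers, while FP8 layers remain handled by `is_layer_fp8`. A minimal sketch of the idea with hypothetical stand-in classes (not the real vLLM types):

```python
# Hypothetical stand-ins for the vLLM quant-method classes.
class UnquantizedLinearMethod: ...
class AWQMarlinLinearMethod: ...
class SomeOtherLinearMethod: ...

class Layer:
    def __init__(self, quant_method):
        self.quant_method = quant_method

def quantization_scheme_supported(layer) -> bool:
    # isinstance with a tuple matches any listed class; FP8 layers are
    # covered by a separate is_layer_fp8 check in the real code.
    return isinstance(layer.quant_method,
                      (UnquantizedLinearMethod, AWQMarlinLinearMethod))

assert quantization_scheme_supported(Layer(AWQMarlinLinearMethod()))
assert not quantization_scheme_supported(Layer(SomeOtherLinearMethod()))
```
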
@@ -289,6 +292,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):

                 return scaled_dequantize(weight, scales,
                                          weight_scale_group_shape)
+            elif isinstance(layer.quant_method, AWQMarlinLinearMethod):
+                return layer.quant_method.decompress_weights(layer).T
             else:
                 return layer.weight

@@ -296,12 +301,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 quantization_scheme_supported(self.q_proj) and\
                 quantization_scheme_supported(self.o_proj)):
             raise NotImplementedError(
-                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                "Only FP8, AWQ, and Unquantized are supported for MLA"
                 ", please run with VLLM_MLA_DISABLE=1")

-        weight_dtype = self.kv_b_proj.weight.dtype
-        assert self.o_proj.weight.dtype == weight_dtype
-        assert self.q_proj.weight.dtype == weight_dtype
+        def get_layer_dtype(layer):
+            if hasattr(layer, "weight"):
+                return layer.weight.dtype
+            elif hasattr(layer, "qweight"):
+                return layer.qweight.dtype
+            else:
+                raise AttributeError(
+                    f"Layer '{layer}' has neither weight nor qweight")
+
+        weight_dtype = get_layer_dtype(self.kv_b_proj)
+        assert get_layer_dtype(self.o_proj) == weight_dtype
+        assert get_layer_dtype(self.q_proj) == weight_dtype

         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (

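The new `get_layer_dtype` helper is needed because AWQ-quantized layers do not expose a floating-point `weight` attribute; they carry a packed integer `qweight` instead. A self-contained sketch with toy layers standing in for vLLM's linear layers (classes and shapes are illustrative only):

```python
import torch

class ToyUnquantizedLinear:
    def __init__(self):
        # unquantized / FP8-style layers expose a floating-point weight
        self.weight = torch.empty(8, 8, dtype=torch.bfloat16)

class ToyAWQLinear:
    def __init__(self):
        # AWQ layers expose only a packed int32 qweight
        self.qweight = torch.empty(8, 1, dtype=torch.int32)

def get_layer_dtype(layer):
    if hasattr(layer, "weight"):
        return layer.weight.dtype
    elif hasattr(layer, "qweight"):
        return layer.qweight.dtype
    raise AttributeError(f"Layer '{layer}' has neither weight nor qweight")

print(get_layer_dtype(ToyUnquantizedLinear()))  # torch.bfloat16
print(get_layer_dtype(ToyAWQLinear()))          # torch.int32
```
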
@@ -990,7 +990,7 @@ class ModelConfig:
             return False

         if self.quantization is not None and self.quantization not in [\
-                "fp8", "compressed-tensors"]:
+                "fp8", "compressed-tensors", "awq_marlin", "moe_wna16"]:
             logger.warning(
                 "MLA is not supported with %s quantization. "
                 "Disabling MLA.", self.quantization)

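With this change, `awq_marlin` and `moe_wna16` join the allow-list, so MLA is no longer disabled for AWQ-Marlin checkpoints. A hedged sketch of the gate's behavior (not the actual `ModelConfig` code path):

```python
# Quantization methods for which MLA stays enabled, per the hunk above.
MLA_SUPPORTED_QUANT = ["fp8", "compressed-tensors", "awq_marlin", "moe_wna16"]

def mla_allowed(quantization):
    # MLA is kept when no quantization is configured or the method is listed;
    # otherwise the config logs a warning and falls back to non-MLA attention.
    return quantization is None or quantization in MLA_SUPPORTED_QUANT

assert mla_allowed("awq_marlin")   # newly allowed by this change
assert not mla_allowed("gptq")     # still disables MLA with a warning
```
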
@@ -242,6 +242,16 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.output_size_per_partition = output_size_per_partition
         layer.num_groups = num_groups

+    def decompress_weights(self, layer: torch.nn.Module) -> torch.Tensor:
+        """
+        Decompress to recover the original unquantized weight.
+        NOTE: this is only to be used before process_weights_after_loading
+        """
+        # We can use AWQ's dequant since the unprocessed weights
+        # are in AWQ format
+        return ops.awq_dequantize(layer.qweight, layer.scales, layer.qzeros, 0,
+                                  0, 0)
+
     # TODO: Update this docs
     # Checkpoints are serialized in AutoAWQ format, which is different from the
     # marlin format. This function is called after the weights are loaded.

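`decompress_weights` reuses `ops.awq_dequantize`, which returns the weight in `(in_features, out_features)` layout, so callers such as the MLA hunk above transpose the result to match the `(out_features, in_features)` layout of an unquantized `layer.weight`. A small shape sketch with hypothetical sizes:

```python
import torch

in_features, out_features, pack_factor = 512, 1024, 8  # illustrative sizes

# Packed AWQ weight as stored right after checkpoint loading.
qweight = torch.zeros(in_features, out_features // pack_factor,
                      dtype=torch.int32)

# Dequantizing expands this to (in_features, out_features); transposing then
# matches the nn.Linear-style layout used by unquantized layers.
dequantized_shape = (qweight.shape[0], qweight.shape[1] * pack_factor)
unquantized_weight = torch.zeros(out_features, in_features,
                                 dtype=torch.float16)
assert dequantized_shape[::-1] == tuple(unquantized_weight.shape)
```
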
@@ -153,6 +153,30 @@ def _initialize_model(
     return model_class(**kwargs)


+def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
+                                   target_device: torch.device) -> None:
+    # Currently only used by MLA.
+    # NOTE: This intentionally happens before other modules so we can easily
+    # decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+            hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
+    for _, module in model.named_modules():
+        quant_method = getattr(module, "quant_method", None)
+        if isinstance(quant_method, QuantizeMethodBase):
+            # When quant methods need to process weights after loading
+            # (for repacking, quantizing, etc), they expect parameters
+            # to be on the global target device. This scope is for the
+            # case where cpu offloading is used, where we will move the
+            # parameters onto device for processing and back off after.
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
+
 class BaseModelLoader(ABC):
     """Base class for model loaders."""

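The ordering comment in `_process_weights_after_loading` is what makes AWQ + MLA work: MLA's attention hook must run while the AWQ layers are still in their packed checkpoint format (so `decompress_weights` applies), and only afterwards do the quant methods repack weights into the Marlin layout. A toy sketch of that ordering, using hypothetical classes rather than the real vLLM modules:

```python
class ToyAWQLayer:
    def __init__(self):
        self.format = "awq"  # packed checkpoint layout, still dequantizable

class ToyAWQMarlinMethod:
    def process_weights_after_loading(self, layer):
        layer.format = "marlin"  # repacked; AWQ dequant no longer applies

class ToyMLAAttention:
    def __init__(self, kv_b_proj):
        self.kv_b_proj = kv_b_proj
    def process_weights_after_loading(self, act_dtype):
        # must see the AWQ layout so it can decompress kv_b_proj
        assert self.kv_b_proj.format == "awq"

layer = ToyAWQLayer()
attn = ToyMLAAttention(layer)

# Pass 1: attention (MLA) modules first, while weights are still AWQ-format.
attn.process_weights_after_loading(act_dtype=None)
# Pass 2: quant methods repack for the Marlin kernels.
ToyAWQMarlinMethod().process_weights_after_loading(layer)
assert layer.format == "marlin"
```
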
@@ -376,7 +400,6 @@ class DefaultModelLoader(BaseModelLoader):
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:

@@ -394,23 +417,8 @@ class DefaultModelLoader(BaseModelLoader):
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")

-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if isinstance(quant_method, QuantizeMethodBase):
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    # TODO(lucas): see if there is a way to unify the signatures
-                    # of process_weights_after_loading
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
+
         return model.eval()

@@ -429,29 +437,15 @@ class DummyModelLoader(BaseModelLoader):
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)

-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(
-                            module, torch.device(device_config.device)):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()

@@ -632,6 +626,7 @@ class ShardedStateLoader(BaseModelLoader):
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         from safetensors.torch import safe_open

         from vllm.distributed import get_tensor_model_parallel_rank

@@ -640,18 +635,10 @@ class ShardedStateLoader(BaseModelLoader):
                                                model_config.revision)

         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
-                    if isinstance(module, Attention) and \
-                        hasattr(module, "process_weights_after_loading"):
-                        # When attention modules need to process weights after
-                        # currently only used by MLA
-                        module.process_weights_after_loading(
-                            model_config.dtype)
+            _process_weights_after_loading(model, model_config,
+                                           target_device)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,

@@ -1401,16 +1388,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
                     self._get_weights_iterator(model_weights,
                                                model_config.revision))

-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()