[XPU][Feature] fp8 online quantization support for XPU (#23148)
Signed-off-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Qiming Zhang <qiming1.zhang@intel.com>
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional
+from typing import Optional, Union

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

@@ -349,3 +350,56 @@ class ipex_ops:
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore

    @staticmethod
    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        scale_ub: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
        output: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 and return quantized tensor and scale.

        This function is designed for both static and dynamic quantization:
        if you provide the scale, static scaling is used, and if you omit it,
        the scale is determined dynamically. Currently, the XPU platform only
        supports dynamic quantization. The function also allows optional
        padding of the output tensors for downstream kernels that will
        benefit from padding.

        Args:
            input: The input tensor to be quantized to FP8
            scale: Optional scaling factor for the FP8 quantization
            scale_ub: Optional upper bound for scaling factor in dynamic
                per token case
            num_token_padding: If specified, pad the first dimension
                of the output to at least this value.
            use_per_token_if_dynamic: Whether to do per_tensor or per_token
                in the dynamic quantization case.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
                scaling factor.
        """
        # This code assumes batch_dim and num_tokens are flattened
        assert (input.ndim == 2)
        shape: Union[tuple[int, int], torch.Size] = input.shape
        out_dtype: torch.dtype = current_platform.fp8_dtype()
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        if output is None:
            output = torch.empty(shape, device=input.device, dtype=out_dtype)
        else:
            assert num_token_padding is None, \
                "padding not supported if output passed in"
            assert output.dtype == out_dtype
        assert scale is None, "only dynamic fp8 quantization supported on XPU"
        assert not use_per_token_if_dynamic, (
            "per token dynamic fp8 quantization not supported on XPU")
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)

        return output, scale

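For reference, a minimal usage sketch of the new dynamic path. It assumes an XPU device and an IPEX build that exposes `torch.ops.torch_ipex.dynamic_scaled_fp8_quant`; the tensor shapes are illustrative only.

```python
# Hedged sketch: dynamic per-tensor FP8 quantization on XPU.
import torch

from vllm._ipex_ops import ipex_ops

# [num_tokens, hidden_size] activations, already flattened to 2-D.
x = torch.randn(16, 4096, dtype=torch.bfloat16, device="xpu")

# scale=None selects the dynamic path; the IPEX kernel fills in the scale.
q, scale = ipex_ops.scaled_fp8_quant(x)

assert q.dtype == torch.float8_e5m2  # XPUPlatform.fp8_dtype() in this commit
assert scale.numel() == 1            # single per-tensor scale
```
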
@@ -137,10 +137,35 @@ class Fp8Config(QuantizationConfig):
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_xpu_quant_method(self, layer: torch.nn.Module,
                             prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod, XPUFp8MoEMethod)
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size)

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,

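A hedged sketch of how this dispatch is exercised. The layer construction details (the `ReplicatedLinear` signature and the `prefix` string) are assumptions for illustration; on non-XPU platforms the existing `Fp8LinearMethod` path is taken instead.

```python
# Hedged sketch: on XPU, Fp8Config hands linear layers to XPUFp8LinearMethod.
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.fp8 import Fp8Config

quant_config = Fp8Config(is_checkpoint_fp8_serialized=False)
proj = ReplicatedLinear(4096, 4096,
                        quant_config=quant_config,
                        prefix="model.layers.0.mlp.down_proj")
# LinearBase stores the result of quant_config.get_quant_method(...).
print(type(proj.quant_method).__name__)  # XPUFp8LinearMethod on XPU hosts
```
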
@@ -1,11 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Any, Optional
+from typing import Any, Callable, Optional

import torch
from packaging import version
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -13,7 +18,10 @@ from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
                                                          is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
                                                         Fp8LinearMethod)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

MIN_IPEX_VERSION = "2.6.0"

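The `packaging` import and `MIN_IPEX_VERSION` support a version gate along these lines. This is only a sketch of the pattern; the exact check and error message in `ipex_quant.py` may differ.

```python
# Hedged sketch of the kind of gate MIN_IPEX_VERSION is used for.
import intel_extension_for_pytorch as ipex
from packaging import version

MIN_IPEX_VERSION = "2.6.0"  # as defined above

if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION):
    raise ImportError(
        f"intel_extension_for_pytorch >= {MIN_IPEX_VERSION} is required, "
        f"but found {ipex.__version__}")
```
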
@@ -251,3 +259,152 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = layer.ipex_qlinear(reshaped_x)
        return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))


class XPUFp8LinearMethod(Fp8LinearMethod):

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)

    def process_weights_after_loading(self, layer: Module) -> None:
        # If the checkpoint is not serialized in fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            # Update the layer with the new values.
            layer.weight = Parameter(qweight, requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight.data
        weight_scale = layer.weight_scale.data
        output = torch.ops.torch_ipex.fp8_gemm_w8a16(x, weight, True,
                                                     weight_scale, bias)
        return output

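A hedged end-to-end sketch of the linear path above, mirroring `process_weights_after_loading()` followed by `apply()`. It assumes an XPU device and an IPEX build that exposes `fp8_gemm_w8a16`; the boolean argument is passed exactly as in `apply()` above, and the sizes are illustrative.

```python
# Hedged sketch: online weight quantization followed by the w8a16 GEMM.
import torch

from vllm._ipex_ops import ipex_ops as ops

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="xpu")
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="xpu")

# Dynamic per-tensor quantization of the weight at load time.
qweight, weight_scale = ops.scaled_fp8_quant(weight, scale=None)

# fp8 weight, 16-bit activation GEMM (bias=None here).
out = torch.ops.torch_ipex.fp8_gemm_w8a16(x, qweight, True,
                                          weight_scale, None)
assert out.shape == (8, 4096)
```
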
class XPUFp8MoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None
        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        # INPUT_SCALES
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        if not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.local_num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
        import intel_extension_for_pytorch as ipex
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            a1_scale_inv=layer.w13_input_scale,
            a2_scale_inv=layer.w2_input_scale,
            use_prepack=True,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return layer.ipex_fusion(
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            custom_routing_function=custom_routing_function,
        )

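For concreteness, a hedged sketch of the shapes that `create_weights` above registers and that the per-expert quantization loop then converts to fp8. The Mixtral-like sizes are illustrative only.

```python
# Hedged sketch: expected MoE parameter shapes for illustrative sizes.
num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 14336  # per tensor-parallel rank

# w13 stacks the gate (w1) and up (w3) projections along dim 1.
w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
w2_shape = (num_experts, hidden_size, intermediate_size_per_partition)
w13_scale_shape = (num_experts, 2)  # one scale each for w1 and w3 at load
w2_scale_shape = (num_experts,)     # single per-tensor scale per expert

print(w13_shape, w2_shape, w13_scale_shape, w2_scale_shape)
```
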
@@ -148,6 +148,10 @@ class XPUPlatform(Platform):
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        return torch.float8_e5m2

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        device_name = cls.get_device_name().lower()

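Taken together, these changes let XPU users request fp8 online quantization the same way as on CUDA: weights are quantized to `torch.float8_e5m2` at load time and activations stay in 16-bit through the w8a16 and MoE paths above. A hedged usage sketch, assuming an XPU build of vLLM plus IPEX; the model name is illustrative.

```python
# Hedged sketch: fp8 online (dynamic) quantization on an Intel XPU host.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
          quantization="fp8",
          dtype="bfloat16")
out = llm.generate(["Hello, XPU!"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```
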