[XPU][Feature] fp8 online quantization support for XPU (#23148)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Qiming Zhang <qiming1.zhang@intel.com>
Authored by Yan Ma on 2025-09-02 12:06:53 +08:00, committed by GitHub
parent 1fa1d6a9a0
commit 7be0cb8e9e
4 changed files with 242 additions and 2 deletions

vllm/_ipex_ops.py

@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -349,3 +350,56 @@ class ipex_ops:
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore

    @staticmethod
    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        scale_ub: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
        output: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 and return quantized tensor and scale.

        This function supports both static and dynamic quantization: if you
        provide the scale, static scaling is used; if you omit it, the scale
        is determined dynamically. Currently, the XPU platform only supports
        dynamic quantization. The function also allows optional padding of
        the output tensors for downstream kernels that benefit from padding.

        Args:
            input: The input tensor to be quantized to FP8.
            scale: Optional scaling factor for the FP8 quantization.
            scale_ub: Optional upper bound for the scaling factor in the
                dynamic per-token case.
            num_token_padding: If specified, pad the first dimension of the
                output to at least this value.
            use_per_token_if_dynamic: Whether to do per-tensor or per-token
                quantization in the dynamic case.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            the scaling factor.
        """
        # This code assumes batch_dim and num_tokens are flattened.
        assert input.ndim == 2
        shape: Union[tuple[int, int], torch.Size] = input.shape
        out_dtype: torch.dtype = current_platform.fp8_dtype()
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        if output is None:
            output = torch.empty(shape, device=input.device, dtype=out_dtype)
        else:
            assert num_token_padding is None, \
                "padding not supported if output passed in"
            assert output.dtype == out_dtype
        assert scale is None, "only dynamic fp8 quantization supported on XPU"
        assert not use_per_token_if_dynamic, (
            "per token dynamic fp8 quantization not supported on XPU")
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)
        return output, scale
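
For illustration, a minimal usage sketch of the new op, assuming an XPU device with a recent IPEX build; the shape and dtype below are arbitrary placeholders:

import torch

from vllm._ipex_ops import ipex_ops

# Dynamic per-tensor quantization: no scale is passed, so the IPEX kernel
# derives one from the input on the fly.
x = torch.randn(16, 4096, dtype=torch.bfloat16, device="xpu")
q, scale = ipex_ops.scaled_fp8_quant(x)
print(q.dtype, q.shape, scale.item())  # fp8 tensor plus a single fp32 scale

# Dequantizing recovers an approximation of the input.
x_hat = q.to(torch.float32) * scale
print((x.float() - x_hat).abs().max())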

vllm/model_executor/layers/quantization/fp8.py

@@ -137,10 +137,35 @@ class Fp8Config(QuantizationConfig):
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_xpu_quant_method(self, layer: torch.nn.Module,
                             prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod, XPUFp8MoEMethod)
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
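
End to end, this dispatch is what enables online FP8 on XPU: requesting fp8 quantization for a bf16/fp16 checkpoint routes linear layers to XPUFp8LinearMethod and fused MoE layers to XPUFp8MoEMethod. A rough usage sketch, with the model name as a placeholder:

from vllm import LLM, SamplingParams

# Weights are quantized to FP8 at load time ("online"); no pre-quantized
# checkpoint is required.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", quantization="fp8")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)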

vllm/model_executor/layers/quantization/ipex_quant.py

@@ -1,11 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any, Callable, Optional
import torch
from packaging import version
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -13,7 +18,10 @@ from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
                                                          is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
                                                         Fp8LinearMethod)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
MIN_IPEX_VERSION = "2.6.0"
@@ -251,3 +259,152 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = layer.ipex_qlinear(reshaped_x)
        return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))


class XPUFp8LinearMethod(Fp8LinearMethod):

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)

    def process_weights_after_loading(self, layer: Module) -> None:
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            # Update the layer with the new values.
            layer.weight = Parameter(qweight, requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight.data
        weight_scale = layer.weight_scale.data
        output = torch.ops.torch_ipex.fp8_gemm_w8a16(x, weight, True,
                                                     weight_scale, bias)
        return output


class XPUFp8MoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # INPUT_SCALES
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        if not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.local_num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)

        import intel_extension_for_pytorch as ipex
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            a1_scale_inv=layer.w13_input_scale,
            a2_scale_inv=layer.w2_input_scale,
            use_prepack=True,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return layer.ipex_fusion(
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            custom_routing_function=custom_routing_function,
        )
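
The per-expert loop above relies on dynamic per-tensor scaling. As a rough illustration of what that scale represents (the actual computation happens inside torch.ops.torch_ipex.dynamic_scaled_fp8_quant; treat this as a sketch of the usual max-abs scheme, not the kernel itself):

import torch

def naive_per_tensor_fp8_quant(w: torch.Tensor,
                               fp8_dtype: torch.dtype = torch.float8_e5m2):
    # scale = amax(|w|) / fp8_max, so the largest weight maps to the largest
    # representable fp8 value; dequantize with q.float() * scale.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = w.abs().max().float() / fp8_max
    q = (w.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q, scale

w = torch.randn(2048, 4096, dtype=torch.bfloat16)
q, s = naive_per_tensor_fp8_quant(w)
print((w.float() - q.float() * s).abs().max())  # quantization error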

vllm/platforms/xpu.py

@@ -148,6 +148,10 @@ class XPUPlatform(Platform):
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        return torch.float8_e5m2

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        device_name = cls.get_device_name().lower()
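
One note on the dtype choice: this override selects torch.float8_e5m2 for XPU, whereas the common default elsewhere in vLLM is torch.float8_e4m3fn. The two formats trade precision for range, which a quick finfo check makes concrete:

import torch

# e4m3fn: more mantissa bits, max ~448  -> finer precision, narrower range.
# e5m2:   more exponent bits, max ~57344 -> coarser precision, wider range.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0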