diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index 79e3e448ca..19f6c4e306 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -349,3 +350,56 @@ class ipex_ops:
     def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                     block_mapping: torch.Tensor) -> None:
         torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore
+
+    @staticmethod
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        scale_ub: Optional[torch.Tensor] = None,
+        use_per_token_if_dynamic: bool = False,
+        output: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantize input tensor to FP8 and return quantized tensor and scale.
+
+        This function is designed for both static and dynamic quantization:
+        if you provide the scale, static scaling is used, and if you omit
+        it, the scale is determined dynamically. Currently, the XPU platform
+        only supports dynamic quantization. The function also allows optional
+        padding of the output tensors for downstream kernels that will
+        benefit from padding.
+
+        Args:
+            input: The input tensor to be quantized to FP8.
+            scale: Optional scaling factor for the FP8 quantization.
+            scale_ub: Optional upper bound for the scaling factor in the
+                dynamic per-token case.
+            num_token_padding: If specified, pad the first dimension
+                of the output to at least this value.
+            use_per_token_if_dynamic: Whether to do per-tensor or per-token
+                quantization in the dynamic case.
+            output: Optional pre-allocated output tensor; must already have
+                the FP8 dtype and cannot be combined with num_token_padding.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+                the scaling factor.
+        """
+        # This code assumes batch_dim and num_tokens are flattened.
+        assert input.ndim == 2
+        shape: Union[tuple[int, int], torch.Size] = input.shape
+        out_dtype: torch.dtype = current_platform.fp8_dtype()
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        if output is None:
+            output = torch.empty(shape, device=input.device, dtype=out_dtype)
+        else:
+            assert num_token_padding is None, \
+                "padding not supported if output passed in"
+            assert output.dtype == out_dtype
+        assert scale is None, "only dynamic fp8 quantization supported on XPU"
+        assert not use_per_token_if_dynamic, (
+            "per token dynamic fp8 quantization not supported on XPU")
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)
+
+        return output, scale
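
A minimal usage sketch of the new op (illustrative only, not part of the patch), assuming an XPU device is available and the IPEX build registers torch.ops.torch_ipex.dynamic_scaled_fp8_quant for bfloat16 inputs:

    import torch
    from vllm._ipex_ops import ipex_ops

    # Dynamic, per-tensor FP8 quantization of a bf16 activation tensor.
    x = torch.randn(16, 4096, dtype=torch.bfloat16, device="xpu")
    q, scale = ipex_ops.scaled_fp8_quant(x)  # scale=None -> dynamic path
    assert q.dtype == torch.float8_e5m2      # XPUPlatform.fp8_dtype()
    assert scale.shape == (1, )              # single per-tensor scale
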
+ """ + # This code assumes batch_dim and num_tokens are flattened + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + out_dtype: torch.dtype = current_platform.fp8_dtype() + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype + assert scale is None, "only dynamic fp8 quantization supported on XPU" + assert not use_per_token_if_dynamic, ( + "per token dynamic fp8 quantization not supported on XPU") + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale) + + return output, scale diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 48bac8697e..d9e01dcf40 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -137,10 +137,35 @@ class Fp8Config(QuantizationConfig): ignored_layers=ignored_layers, weight_block_size=weight_block_size) + def get_xpu_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention + from vllm.model_executor.layers.quantization.ipex_quant import ( + XPUFp8LinearMethod, XPUFp8MoEMethod) + fp8_config = Fp8Config( + is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized, + activation_scheme=self.activation_scheme, + ignored_layers=self.ignored_layers, + weight_block_size=self.weight_block_size) + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + return XPUFp8LinearMethod(fp8_config) + elif isinstance(layer, FusedMoE): + return XPUFp8MoEMethod(fp8_config, layer) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + if current_platform.is_xpu(): + return self.get_xpu_quant_method(layer, prefix) if isinstance(layer, LinearBase): if is_layer_skipped(prefix=prefix, ignored_layers=self.ignored_layers, diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 9c458954f9..5f9d481427 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -1,11 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any, Callable, Optional import torch from packaging import version +from torch.nn import Module +from torch.nn.parameter import Parameter +from vllm._ipex_ops import ipex_ops as ops +from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, + FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -13,7 +18,10 @@ from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, is_layer_skipped_awq) from vllm.model_executor.layers.quantization.base_config import ( 
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
+                                                         Fp8LinearMethod)
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
 MIN_IPEX_VERSION = "2.6.0"
@@ -251,3 +259,152 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = layer.ipex_qlinear(reshaped_x)
         return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
+
+
+class XPUFp8LinearMethod(Fp8LinearMethod):
+
+    def __init__(self, quant_config: Fp8Config):
+        super().__init__(quant_config)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # If checkpoint not serialized fp8, quantize the weights.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
+                                                         scale=None)
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        weight = layer.weight.data
+        weight_scale = layer.weight_scale.data
+        output = torch.ops.torch_ipex.fp8_gemm_w8a16(x, weight, True,
+                                                     weight_scale, bias)
+        return output
+
+
+class XPUFp8MoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
+        super().__init__(layer.moe_config)
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
+                       intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.hidden_size = hidden_size
+        layer.num_experts = num_experts
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                         2,
+                                                         dtype=torch.float32),
+                                              requires_grad=False)
+        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                        dtype=torch.float32),
+                                             requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+        # INPUT_SCALES
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            fp8_dtype = current_platform.fp8_dtype()
+            w13_weight = torch.empty_like(layer.w13_weight.data,
+                                          dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
+                layer.local_num_experts,
+                dtype=torch.float32,
+                device=w13_weight.device),
+                                                        requires_grad=False)
+            for expert in range(layer.local_num_experts):
+                w13_weight[expert, :, :], layer.w13_weight_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], layer.w2_weight_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w2_weight.data[expert, :, :])
+            layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                 requires_grad=False)
+        import intel_extension_for_pytorch as ipex
+        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+            layer.w13_weight,
+            layer.w2_weight,
+            w1_scale_inv=layer.w13_weight_scale,
+            w2_scale_inv=layer.w2_weight_scale,
+            a1_scale_inv=layer.w13_input_scale,
+            a2_scale_inv=layer.w2_input_scale,
+            use_prepack=True,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
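
For intuition (not part of the patch): XPUFp8LinearMethod.apply runs a w8a16 GEMM, i.e. FP8 weights with 16-bit activations and a single per-tensor weight scale. A rough unfused equivalent, written against the per-tensor scale convention of scaled_fp8_quant rather than the IPEX kernel itself (fp8_linear_w8a16_reference is a hypothetical name):

    from typing import Optional

    import torch

    def fp8_linear_w8a16_reference(
            x: torch.Tensor,
            weight_fp8: torch.Tensor,    # [out_features, in_features], FP8
            weight_scale: torch.Tensor,  # per-tensor scale from scaled_fp8_quant
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dequantize the FP8 weight with its scale, then run a plain 16-bit
        # GEMM; the fused IPEX kernel avoids materializing the dequantized
        # weight.
        w = weight_fp8.to(x.dtype) * weight_scale.to(x.dtype)
        return torch.nn.functional.linear(x, w, bias)
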
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 84f4cd7256..d61b921e19 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -148,6 +148,10 @@ class XPUPlatform(Platform):
         torch.xpu.reset_peak_memory_stats(device)
         return torch.xpu.max_memory_allocated(device)
 
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        return torch.float8_e5m2
+
     @classmethod
     def is_data_center_gpu(cls) -> bool:
         device_name = cls.get_device_name().lower()
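
Note that torch.float8_e5m2 is the wider-range, lower-precision of the two common FP8 formats (most other platforms in vLLM default to float8_e4m3fn); a quick comparison in PyTorch, for reference only:

    import torch

    # e5m2: 5 exponent bits, 2 mantissa bits -> larger dynamic range.
    # e4m3fn: 4 exponent bits, 3 mantissa bits -> finer precision.
    print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
    print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
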