Compare commits

...

6 Commits

View File

@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
@ -94,11 +93,9 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
@ -546,83 +543,19 @@ class Fp8LinearMethod(LinearMethodBase):
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
# Call is_deep_gemm_supported() ahead of time for torch.compile
# dynamo has trouble tracing through
if self.block_quant and should_use_deepgemm_for_fp8_linear(
torch.bfloat16, layer.weight, self.use_deep_gemm
):
# use group quant consistent with block size across K
assert self.act_q_group_shape is not None
q_input, input_scale = QuantFP8(
False,
self.act_q_group_shape,
column_major_scales=True,
)(x)
output_2d = torch.empty(
(q_input.shape[0], layer.weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device,
)
fp8_gemm_nt(
(q_input, input_scale),
(layer.weight, layer.weight_scale),
output_2d,
)
if bias is not None:
output_2d = output_2d + bias
return output_2d
# Dequantize FP8 weights to BF16
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
# Handle different quantization granularities
if self.block_quant:
# Block-wise quantization:
# - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
# - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
assert self.weight_block_size is not None
block_n, block_k = self.weight_block_size # Note: order is [N, K]
N, K = weight_fp8.shape
# determine expected number of blocks along N and K
num_blocks_n = (N + block_n - 1) // block_n
num_blocks_k = (K + block_k - 1) // block_k
# scale layout may be [num_blocks_n, num_blocks_k]
# or [num_blocks_k, num_blocks_n] depending on backend
if weight_scale.dim() != 2:
raise RuntimeError(
f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
)
scale_rows, scale_cols = weight_scale.shape
if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
if num_blocks_n == num_blocks_k:
# ambiguous square case, warn and skip transpose
logger.warning(
"Batch-invariant FP8: square block-scale %dx%d; "
"skipping transpose to avoid misorientation.",
scale_rows,
scale_cols,
)
else:
# clear KN -> transpose to NK
weight_scale = weight_scale.t()
# Expand scale to match weight dimensions
# scale_expanded should have shape [N, K]
scale_expanded = weight_scale.repeat_interleave(
block_n, dim=0
).repeat_interleave(block_k, dim=1)
# Trim to exact weight size (in case of padding)
scale_expanded = scale_expanded[:N, :K]
weight_bf16 = weight_fp8 * scale_expanded
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
else:
# Per-tensor quantization: weight IS transposed to [K, N]
# scale should be scalar or [1] or per-output-channel [N]
# per-tensor/channel: dequant to BF16 and run GEMM
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
@ -641,16 +574,7 @@ class Fp8LinearMethod(LinearMethodBase):
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale
# For block quant, weight is [N, K], for per-tensor it's [K, N]
# F.linear expects weight to be [N, K], so:
if self.block_quant:
# Already in correct shape [N, K]
output = torch.nn.functional.linear(x, weight_bf16, bias)
else:
# Need to transpose back: [K, N] -> [N, K]
output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
return output
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
return apply_fp8_marlin_linear(