[Bugfix][CPU] Fallback oneDNN linear to torch linear to fix half gemm support on legacy platforms (#27526)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

Author: Li, Jiang
Date: 2025-10-28 14:25:44 +08:00
Committed by: GitHub
Parent: bdb01a38fe
Commit: d34f5fe939
2 changed files with 22 additions and 10 deletions
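
Summary: `dispatch_cpu_unquantized_gemm` used to commit to the oneDNN path whenever the custom ops reported support and the CPU architecture was not POWERPC, so a platform that passed those checks but whose oneDNN build could not actually create the GEMM primitive (e.g. half-precision GEMM on legacy CPUs) hit an unrecovered RuntimeError. This commit wraps the oneDNN handler creation in a try/except and falls back to `torch.nn.functional.linear`, logging a one-time warning instead.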


@@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc
 
 ######################### BUILD IMAGE #########################
 FROM base AS vllm-build
-ARG max_jobs=2
+ARG max_jobs=32
 ENV MAX_JOBS=${max_jobs}
 ARG GIT_REPO_CHECK=0
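Build note: `MAX_JOBS` caps the number of parallel compile jobs during the image build, so raising the default from 2 to 32 mainly speeds up well-provisioned builders. On memory-constrained machines the argument can still be overridden at build time, e.g. `docker build --build-arg max_jobs=8 ...`.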


@@ -8,9 +8,12 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm import envs
+from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 
+logger = init_logger(__name__)
+
 
 def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
     # Shuffle weight along the last dimension so that
@@ -178,19 +181,28 @@ def dispatch_cpu_unquantized_gemm(
         )
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+        return
     elif (
         ops._supports_onednn
         and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
     ):
-        origin_weight = layer.weight
-        if remove_weight:
-            layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
-        handler = ops.create_onednn_mm(origin_weight.t(), 32)
-        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
-    else:
-        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
-            x, weight, bias
-        )
+        try:
+            origin_weight = layer.weight
+            handler = ops.create_onednn_mm(origin_weight.t(), 32)
+            layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
+            if remove_weight:
+                layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+            return
+        except RuntimeError as e:
+            logger.warning_once(
+                "Failed to create oneDNN linear, fallback to torch linear."
+                f" Exception: {e}"
+            )
+
+    # fallback case
+    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
+        x, weight, bias
+    )
 
 
 def cpu_unquantized_gemm(
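
The new control flow is easier to read outside the diff. The sketch below is a minimal, self-contained reproduction of the try/probe/fallback pattern; `_probe_fast_gemm` and `pick_cpu_linear` are hypothetical stand-ins for `ops.create_onednn_mm` and the dispatch function, and the probe raises unconditionally to simulate a legacy CPU (the real handler creation only raises where oneDNN lacks the required GEMM support).

import torch


def _probe_fast_gemm(weight: torch.Tensor):
    # Hypothetical stand-in for ops.create_onednn_mm(weight.t(), 32):
    # simulate a legacy CPU whose oneDNN build rejects half-precision GEMM.
    raise RuntimeError("fp16 gemm unsupported on this ISA (simulated)")


def pick_cpu_linear(layer: torch.nn.Module) -> None:
    # Same shape as the patched dispatch: try the fast path first,
    # fall back to torch.nn.functional.linear on RuntimeError.
    try:
        _handler = _probe_fast_gemm(layer.weight)  # always raises in this sketch
        # The real code would bind: lambda x, w, b: ops.onednn_mm(_handler, x, b)
        return
    except RuntimeError as e:
        print(f"Failed to create oneDNN linear, fallback to torch linear. Exception: {e}")
    # fallback case
    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
        x, weight, bias
    )


layer = torch.nn.Linear(16, 8)
pick_cpu_linear(layer)
out = layer.cpu_linear(torch.randn(4, 16), layer.weight, layer.bias)
print(out.shape)  # torch.Size([4, 8])

Because the except clause falls through instead of returning, the torch fallback line also covers platforms that never enter the oneDNN branch, which is why the patch hoists it out of the old `else:`.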