[Bugfix][CPU] Fallback oneDNN linear to torch linear to fix half gemm support on legacy platforms (#27526)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc
 ######################### BUILD IMAGE #########################
 FROM base AS vllm-build
 
-ARG max_jobs=2
+ARG max_jobs=32
 ENV MAX_JOBS=${max_jobs}
 
 ARG GIT_REPO_CHECK=0
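For context (not shown in this hunk): max_jobs is only a default. ENV MAX_JOBS=${max_jobs} exposes it to the build steps, and it can still be overridden per build with docker build --build-arg max_jobs=<n> (standard Docker ARG behavior); the higher default simply allows more parallel compile jobs by default.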
@@ -8,9 +8,12 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm import envs
+from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 
+logger = init_logger(__name__)
+
 
 def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
     # Shuffle weight along the last dimension so that
@@ -178,19 +181,28 @@ def dispatch_cpu_unquantized_gemm(
         )
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
         return
     elif (
         ops._supports_onednn
         and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
     ):
-        origin_weight = layer.weight
-        if remove_weight:
-            layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
-        handler = ops.create_onednn_mm(origin_weight.t(), 32)
-        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
-    else:
-        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
-            x, weight, bias
-        )
+        try:
+            origin_weight = layer.weight
+            handler = ops.create_onednn_mm(origin_weight.t(), 32)
+            layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
+            if remove_weight:
+                layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+            return
+        except RuntimeError as e:
+            logger.warning_once(
+                "Failed to create oneDNN linear, fallback to torch linear."
+                f" Exception: {e}"
+            )
+
+    # fallback case
+    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
+        x, weight, bias
+    )
 
 
 def cpu_unquantized_gemm(
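To make the change easier to follow, here is a minimal, self-contained sketch of the same dispatch-with-fallback pattern. It is not the vLLM implementation: dispatch_linear and _FailingBackend below are hypothetical stand-ins for dispatch_cpu_unquantized_gemm and the oneDNN ops, used only so the example runs without vLLM installed.

import torch


class _FailingBackend:
    # Stand-in for an accelerated backend whose handler creation can fail,
    # e.g. half-precision GEMM being unsupported on an older CPU.
    def create_mm(self, weight_t: torch.Tensor):
        raise RuntimeError("fp16 GEMM not supported on this CPU")


def dispatch_linear(layer: torch.nn.Linear, backend) -> None:
    # Try to build an accelerated matmul handler at dispatch time; if that
    # raises, fall back to torch.nn.functional.linear, which works wherever
    # torch itself does.
    try:
        handler = backend.create_mm(layer.weight.t())
        layer.cpu_linear = lambda x, weight, bias: handler(x, bias)
        return
    except RuntimeError as e:
        print(f"Failed to create accelerated linear, falling back to torch linear: {e}")

    # fallback case
    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(x, weight, bias)


layer = torch.nn.Linear(8, 4)
dispatch_linear(layer, _FailingBackend())
out = layer.cpu_linear(torch.randn(2, 8), layer.weight, layer.bias)
print(out.shape)  # torch.Size([2, 4])

The point mirrors the diff above: handler creation happens once, eagerly, so an unsupported dtype or ISA surfaces as a single RuntimeError and the layer is permanently wired to the plain torch path instead of failing on every forward call.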