diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index f3fd1ee3e3..adaf8a3c5b 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc ######################### BUILD IMAGE ######################### FROM base AS vllm-build -ARG max_jobs=2 +ARG max_jobs=32 ENV MAX_JOBS=${max_jobs} ARG GIT_REPO_CHECK=0 diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index e6b6a70afd..da5eea02d1 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -8,9 +8,12 @@ import torch from vllm import _custom_ops as ops from vllm import envs +from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.utils.torch_utils import direct_register_custom_op +logger = init_logger(__name__) + def shuffle_weight(w: torch.Tensor) -> torch.Tensor: # Shuffle weight along the last dimension so that @@ -178,19 +181,28 @@ def dispatch_cpu_unquantized_gemm( ) if remove_weight: layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) + return elif ( ops._supports_onednn and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC ): - origin_weight = layer.weight - if remove_weight: - layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) - handler = ops.create_onednn_mm(origin_weight.t(), 32) - layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias) - else: - layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear( - x, weight, bias - ) + try: + origin_weight = layer.weight + handler = ops.create_onednn_mm(origin_weight.t(), 32) + layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias) + if remove_weight: + layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) + return + except RuntimeError as e: + logger.warning_once( + "Failed to create oneDNN linear, fallback to torch linear." + f" Exception: {e}" + ) + + # fallback case + layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear( + x, weight, bias + ) def cpu_unquantized_gemm(