[Kernel] Centralize platform kernel import in current_platform.import_kernels (#26286)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-10-08 22:25:31 +02:00
committed by GitHub
parent e1ba235668
commit 4ebc9108a7
4 changed files with 13 additions and 15 deletions

View File

@ -12,8 +12,7 @@ from vllm.scalar_type import ScalarType
logger = init_logger(__name__)
current_platform.import_core_kernels()
supports_moe_ops = current_platform.try_import_moe_kernels()
current_platform.import_kernels()
if TYPE_CHECKING:
@ -1921,7 +1920,7 @@ def moe_wna16_marlin_gemm(
)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(

View File

@ -170,22 +170,15 @@ class Platform:
return device_id
@classmethod
def import_core_kernels(cls) -> None:
def import_kernels(cls) -> None:
"""Import any platform-specific C kernels."""
try:
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C: %r", e)
@classmethod
def try_import_moe_kernels(cls) -> bool:
"""Import any platform-specific MoE kernels."""
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
return True
return False
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, Optional, Union, cast
import torch
@ -45,8 +46,10 @@ class TpuPlatform(Platform):
additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
@classmethod
def import_core_kernels(cls) -> None:
pass
def import_kernels(cls) -> None:
# Do not import vllm._C
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_attn_backend_cls(

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
from typing import TYPE_CHECKING, Optional
@ -35,8 +36,10 @@ class XPUPlatform(Platform):
device_control_env_var: str = "ZE_AFFINITY_MASK"
@classmethod
def import_core_kernels(cls) -> None:
pass
def import_kernels(cls) -> None:
# Do not import vllm._C
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_attn_backend_cls(