[Kernel] Centralize platform kernel import in current_platform.import_kernels (#26286)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@ -12,8 +12,7 @@ from vllm.scalar_type import ScalarType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
current_platform.import_core_kernels()
|
||||
supports_moe_ops = current_platform.try_import_moe_kernels()
|
||||
current_platform.import_kernels()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -1921,7 +1920,7 @@ def moe_wna16_marlin_gemm(
|
||||
)
|
||||
|
||||
|
||||
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
||||
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
||||
|
||||
@register_fake("_moe_C::marlin_gemm_moe")
|
||||
def marlin_gemm_moe_fake(
|
||||
|
||||
@ -170,22 +170,15 @@ class Platform:
|
||||
return device_id
|
||||
|
||||
@classmethod
|
||||
def import_core_kernels(cls) -> None:
|
||||
def import_kernels(cls) -> None:
|
||||
"""Import any platform-specific C kernels."""
|
||||
try:
|
||||
import vllm._C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C: %r", e)
|
||||
|
||||
@classmethod
|
||||
def try_import_moe_kernels(cls) -> bool:
|
||||
"""Import any platform-specific MoE kernels."""
|
||||
with contextlib.suppress(ImportError):
|
||||
import vllm._moe_C # noqa: F401
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from typing import TYPE_CHECKING, Optional, Union, cast
|
||||
|
||||
import torch
|
||||
@ -45,8 +46,10 @@ class TpuPlatform(Platform):
|
||||
additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
|
||||
|
||||
@classmethod
|
||||
def import_core_kernels(cls) -> None:
|
||||
pass
|
||||
def import_kernels(cls) -> None:
|
||||
# Do not import vllm._C
|
||||
with contextlib.suppress(ImportError):
|
||||
import vllm._moe_C # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
@ -35,8 +36,10 @@ class XPUPlatform(Platform):
|
||||
device_control_env_var: str = "ZE_AFFINITY_MASK"
|
||||
|
||||
@classmethod
|
||||
def import_core_kernels(cls) -> None:
|
||||
pass
|
||||
def import_kernels(cls) -> None:
|
||||
# Do not import vllm._C
|
||||
with contextlib.suppress(ImportError):
|
||||
import vllm._moe_C # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
|
||||
Reference in New Issue
Block a user