[Bugfix] Add device assertion to TorchSDPA (#5402)

This commit is contained in:
Li, Jiang
2024-06-13 03:58:53 +08:00
committed by GitHub
parent 1a8bfd92d5
commit c3c2903e72

View File

@ -58,6 +58,9 @@ def get_attn_backend(
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
# TODO: make XPU backend available here.
assert is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend