diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 1518e518e9..1ad5e6e8e4 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -10,7 +10,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
@@ -137,9 +138,17 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
 
         self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                               '0').lower() in ['1', 'true']
+        self.fused_scaled_dot_product_attention = None
         if self.prefill_usefusedsdpa:
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
+            try:
+                from habana_frameworks.torch.hpex.kernels import FusedSDPA
+                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
+                    FusedSDPA)
+            except ImportError:
+                logger().warning("Could not import HPU FusedSDPA kernel. "
+                                 "vLLM will use native implementation.")
 
         suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in suppored_head_sizes:
@@ -227,6 +236,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
                 matmul_qk_op=self.matmul_qk,
                 softmax_op=self.softmax,
                 matmul_av_op=self.matmul_av,
+                fsdpa_op=self.fused_scaled_dot_product_attention,
             )
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
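
Note: the hunks above amount to a gate-and-fallback pattern: VLLM_PROMPT_USE_FUSEDSDPA enables the fused prompt path, the Habana FusedSDPA kernel is imported lazily and wrapped in ModuleFusedSDPA, and a failed import leaves the op as None so the native implementation stays in effect when it is passed through as fsdpa_op in the third hunk. The sketch below illustrates that selection logic in isolation; select_prompt_fsdpa_op is a hypothetical helper written for this note, not a function added by the patch, and the ModuleFusedSDPA wrapping is elided to keep the sketch dependency-free.

# Illustrative sketch of the gate-and-fallback selection performed in
# HPUAttentionImpl.__init__ above; select_prompt_fsdpa_op is hypothetical.
import logging
import os

logger = logging.getLogger(__name__)


def select_prompt_fsdpa_op():
    """Return a fused SDPA op if enabled and importable, else None."""
    enabled = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                        '0').lower() in ['1', 'true']
    if not enabled:
        return None
    try:
        # Same lazy import as the patch; the real backend wraps this in
        # ModuleFusedSDPA from vllm_hpu_extension.utils.
        from habana_frameworks.torch.hpex.kernels import FusedSDPA
        return FusedSDPA
    except ImportError:
        logger.warning("Could not import HPU FusedSDPA kernel. "
                       "vLLM will use native implementation.")
        return None

A None result simply means the prompt attention call receives fsdpa_op=None and keeps using the existing matmul_qk/softmax/matmul_av path, which is why the patch needs no other changes on the non-HPU fallback side.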