[Hardware][Intel-Gaudi] Enable FusedSDPA support for Intel Gaudi (HPU)
This commit is contained in:
committed by
GitHub
parent
4c3aac51e1
commit
af8486de49
@ -10,7 +10,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import vllm_hpu_extension.ops as ops
|
||||
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
|
||||
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
|
||||
VLLMKVCache)
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
@ -137,9 +138,17 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
||||
|
||||
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
|
||||
'0').lower() in ['1', 'true']
|
||||
self.fused_scaled_dot_product_attention = None
|
||||
if self.prefill_usefusedsdpa:
|
||||
assert alibi_slopes is None, \
|
||||
'Prefill with FusedSDPA not supported with alibi slopes!'
|
||||
try:
|
||||
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
|
||||
FusedSDPA)
|
||||
except ImportError:
|
||||
logger().warning("Could not import HPU FusedSDPA kernel. "
|
||||
"vLLM will use native implementation.")
|
||||
|
||||
suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
|
||||
if head_size not in suppored_head_sizes:
|
||||
@ -227,6 +236,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
||||
matmul_qk_op=self.matmul_qk,
|
||||
softmax_op=self.softmax,
|
||||
matmul_av_op=self.matmul_av,
|
||||
fsdpa_op=self.fused_scaled_dot_product_attention,
|
||||
)
|
||||
output = out.reshape(batch_size, seq_len, hidden_size)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user