[Bug] Fix DeepGEMM Attention Test (#26423)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -184,6 +184,7 @@ ba = "ba"
 [tool.typos.type.py.extend-words]
 ba = "ba"
+nd = "nd"
 
 [tool.typos.type.cpp]
 extend-glob = ["*.cu"]
@@ -82,8 +82,7 @@ def _ref_fp8_mqa_logits(
         torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
     )
     mask = mask_lo & mask_hi
-
-    score = torch.einsum("mhd,and->hmn", q, k)
+    score = torch.einsum("mhd,nd->hmn", q, k)
     logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
     logits = logits.masked_fill(~mask, float("-inf"))
 
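Note: a minimal sketch of why the corrected einsum subscripts work, using hypothetical shapes not taken from the test. With q of shape (seq_len_q, num_heads, head_dim) and k of shape (seq_len_kv, head_dim), "mhd,nd->hmn" contracts over head_dim and produces per-head scores, whereas the old "mhd,and->hmn" names three axes for the 2-D k and is rejected by torch.einsum.

import torch

# Hypothetical small shapes, for illustration only.
m, h, n, d = 4, 2, 6, 8  # seq_len_q, num_heads, seq_len_kv, head_dim
q = torch.randn(m, h, d)
k = torch.randn(n, d)

# Corrected subscripts: contract over head_dim, yielding per-head scores (h, m, n).
score = torch.einsum("mhd,nd->hmn", q, k)
assert score.shape == (h, m, n)

# The old subscripts "mhd,and->hmn" declare three axes for the 2-D k,
# so torch.einsum raises a dimension-mismatch error instead of computing scores.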