fix gqa issue for blackwell fmha.py (#2599)

This commit is contained in:
Linfeng Zheng
2025-08-28 23:15:20 +08:00
committed by GitHub
parent a49a78ffef
commit 9ca7e877b2

View File

@ -437,7 +437,7 @@ class BlackwellFusedMultiHeadAttentionForward:
# (s, d, ((h_r, h_k), b))
q_layout = cute.make_layout(
(s_q, d, ((h_r, h_k), b_qo)),
stride=(d * h_r * h_k, 1, ((d * h_k, d), stride_b_qo)),
stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)),
)
q = cute.make_tensor(q_iter + qo_offset, q_layout)
# (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast
@ -455,7 +455,7 @@ class BlackwellFusedMultiHeadAttentionForward:
# (s, d, ((h_r, h_k), b))
o_layout = cute.make_layout(
(s_q, d, ((h_r, h_k), b_qo)),
stride=(d * h_r * h_k, 1, ((d * h_k, d), stride_b_qo)),
stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)),
)
o = cute.make_tensor(o_iter + qo_offset, o_layout)
@ -2953,32 +2953,13 @@ def run(
h_q = q.shape[2]
h_k = k.shape[2]
if not h_q == h_k:
repeat_factor = h_q // h_k
# nested tensor can not be broadcasted directly
if k.is_nested:
k_offsets = k.offsets()
v_offsets = v.offsets()
k_values = k.values().repeat(1, repeat_factor, 1)
v_values = v.values().repeat(1, repeat_factor, 1)
k = torch.nested.nested_tensor_from_jagged(
values=k_values, offsets=k_offsets
)
v = torch.nested.nested_tensor_from_jagged(
values=v_values, offsets=v_offsets
)
else:
k = k.repeat(1, 1, repeat_factor, 1)
v = v.repeat(1, 1, repeat_factor, 1)
# as we initialize q, k, v with shape (b, s, h, d) and SDPA of torch needs them to be (b, h, s, d)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# For the situation that torch has not supported, we need to handle it manually
situation1 = is_causal and (q.is_nested or k.is_nested)
situation1 = (h_q != h_k or is_causal) and (q.is_nested or k.is_nested)
situation2 = (q.is_nested and not k.is_nested) or (
not q.is_nested and k.is_nested
)
@ -2999,6 +2980,7 @@ def run(
dropout_p=0.0,
scale=scale_softmax,
is_causal=is_causal,
enable_gqa=(h_q != h_k),
)
ref_i = ref_i.transpose(0, 1) * scale_output
ref_list.append(ref_i)
@ -3015,6 +2997,7 @@ def run(
dropout_p=0.0,
scale=scale_softmax,
is_causal=is_causal,
enable_gqa=(h_q != h_k),
)
ref = ref.transpose(1, 2) * scale_output
return ref