fix gqa issue for blackwell fmha.py (#2599)
This commit is contained in:
@ -437,7 +437,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
# (s, d, ((h_r, h_k), b))
|
||||
q_layout = cute.make_layout(
|
||||
(s_q, d, ((h_r, h_k), b_qo)),
|
||||
stride=(d * h_r * h_k, 1, ((d * h_k, d), stride_b_qo)),
|
||||
stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)),
|
||||
)
|
||||
q = cute.make_tensor(q_iter + qo_offset, q_layout)
|
||||
# (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast
|
||||
@ -455,7 +455,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
# (s, d, ((h_r, h_k), b))
|
||||
o_layout = cute.make_layout(
|
||||
(s_q, d, ((h_r, h_k), b_qo)),
|
||||
stride=(d * h_r * h_k, 1, ((d * h_k, d), stride_b_qo)),
|
||||
stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)),
|
||||
)
|
||||
o = cute.make_tensor(o_iter + qo_offset, o_layout)
|
||||
|
||||
@ -2953,32 +2953,13 @@ def run(
|
||||
h_q = q.shape[2]
|
||||
h_k = k.shape[2]
|
||||
|
||||
if not h_q == h_k:
|
||||
repeat_factor = h_q // h_k
|
||||
# a nested tensor cannot be broadcast directly
|
||||
if k.is_nested:
|
||||
k_offsets = k.offsets()
|
||||
v_offsets = v.offsets()
|
||||
k_values = k.values().repeat(1, repeat_factor, 1)
|
||||
v_values = v.values().repeat(1, repeat_factor, 1)
|
||||
|
||||
k = torch.nested.nested_tensor_from_jagged(
|
||||
values=k_values, offsets=k_offsets
|
||||
)
|
||||
v = torch.nested.nested_tensor_from_jagged(
|
||||
values=v_values, offsets=v_offsets
|
||||
)
|
||||
else:
|
||||
k = k.repeat(1, 1, repeat_factor, 1)
|
||||
v = v.repeat(1, 1, repeat_factor, 1)
|
||||
|
||||
# q, k, v are initialized with shape (b, s, h, d), but torch's SDPA needs them to be (b, h, s, d)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# Situations that torch does not support must be handled manually
|
||||
situation1 = is_causal and (q.is_nested or k.is_nested)
|
||||
situation1 = (h_q != h_k or is_causal) and (q.is_nested or k.is_nested)
|
||||
situation2 = (q.is_nested and not k.is_nested) or (
|
||||
not q.is_nested and k.is_nested
|
||||
)
|
||||
@ -2999,6 +2980,7 @@ def run(
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=is_causal,
|
||||
enable_gqa=(h_q != h_k),
|
||||
)
|
||||
ref_i = ref_i.transpose(0, 1) * scale_output
|
||||
ref_list.append(ref_i)
|
||||
@ -3015,6 +2997,7 @@ def run(
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=is_causal,
|
||||
enable_gqa=(h_q != h_k),
|
||||
)
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
return ref
|
||||
|
||||
Reference in New Issue
Block a user