From 9ca7e877b24cef095fef92a7aa25d3795b74f69d Mon Sep 17 00:00:00 2001 From: Linfeng Zheng Date: Thu, 28 Aug 2025 23:15:20 +0800 Subject: [PATCH] fix gqa issue for blackwell fmha.py (#2599) --- examples/python/CuTeDSL/blackwell/fmha.py | 27 +++++------------------ 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/examples/python/CuTeDSL/blackwell/fmha.py b/examples/python/CuTeDSL/blackwell/fmha.py index 537d9b43..259ddb85 100644 --- a/examples/python/CuTeDSL/blackwell/fmha.py +++ b/examples/python/CuTeDSL/blackwell/fmha.py @@ -437,7 +437,7 @@ class BlackwellFusedMultiHeadAttentionForward: # (s, d, ((h_r, h_k), b)) q_layout = cute.make_layout( (s_q, d, ((h_r, h_k), b_qo)), - stride=(d * h_r * h_k, 1, ((d * h_k, d), stride_b_qo)), + stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)), ) q = cute.make_tensor(q_iter + qo_offset, q_layout) # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast @@ -455,7 +455,7 @@ class BlackwellFusedMultiHeadAttentionForward: # (s, d, ((h_r, h_k), b)) o_layout = cute.make_layout( (s_q, d, ((h_r, h_k), b_qo)), - stride=(d * h_r * h_k, 1, ((d * h_k, d), stride_b_qo)), + stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)), ) o = cute.make_tensor(o_iter + qo_offset, o_layout) @@ -2953,32 +2953,13 @@ def run( h_q = q.shape[2] h_k = k.shape[2] - if not h_q == h_k: - repeat_factor = h_q // h_k - # nested tensor can not be broadcasted directly - if k.is_nested: - k_offsets = k.offsets() - v_offsets = v.offsets() - k_values = k.values().repeat(1, repeat_factor, 1) - v_values = v.values().repeat(1, repeat_factor, 1) - - k = torch.nested.nested_tensor_from_jagged( - values=k_values, offsets=k_offsets - ) - v = torch.nested.nested_tensor_from_jagged( - values=v_values, offsets=v_offsets - ) - else: - k = k.repeat(1, 1, repeat_factor, 1) - v = v.repeat(1, 1, repeat_factor, 1) - # as we initialize q, k, v with shape (b, s, h, d) and SDPA of torch needs them to be (b, h, s, d) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # For the situation that torch has not supported, we need to handle it manually - situation1 = is_causal and (q.is_nested or k.is_nested) + situation1 = (h_q != h_k or is_causal) and (q.is_nested or k.is_nested) situation2 = (q.is_nested and not k.is_nested) or ( not q.is_nested and k.is_nested ) @@ -2999,6 +2980,7 @@ def run( dropout_p=0.0, scale=scale_softmax, is_causal=is_causal, + enable_gqa=(h_q != h_k), ) ref_i = ref_i.transpose(0, 1) * scale_output ref_list.append(ref_i) @@ -3015,6 +2997,7 @@ def run( dropout_p=0.0, scale=scale_softmax, is_causal=is_causal, + enable_gqa=(h_q != h_k), ) ref = ref.transpose(1, 2) * scale_output return ref