[Bugfix][Model] OLMo 2: split qkv correctly for GQA and MQA (#13687)

Author: Shane A
Date: 2025-02-21 22:07:45 -08:00 (committed by GitHub)
parent 68d630a0c7
commit 9a1f1da5d1


@@ -157,7 +157,7 @@ class Olmo2Attention(nn.Module):
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
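
Why the fix matters: with grouped-query attention (GQA), the fused QKV projection packs a query block of width q_size = num_heads * head_dim followed by two smaller KV blocks of width kv_size = num_kv_heads * head_dim. Splitting that tensor into three equal chunks misaligns all three outputs whenever num_kv_heads < num_heads. The following is a minimal standalone sketch, not vLLM code; the head counts and head_dim are illustrative assumptions chosen to make the arithmetic visible.

    import torch

    # Illustrative GQA configuration (assumed values, not OLMo 2's actual config).
    num_heads = 32      # query heads
    num_kv_heads = 8    # KV heads; 1 would be MQA, num_heads would be plain MHA
    head_dim = 128

    q_size = num_heads * head_dim       # 4096
    kv_size = num_kv_heads * head_dim   # 1024

    # Fused QKV projection output has width q_size + 2 * kv_size = 6144.
    qkv = torch.randn(2, q_size + 2 * kv_size)

    # Buggy: chunk() cuts three equal parts (2048 each), which only
    # matches the layout when num_kv_heads == num_heads.
    q_bad, k_bad, v_bad = qkv.chunk(chunks=3, dim=-1)
    assert q_bad.shape[-1] == 2048  # wrong: q should be 4096 wide

    # Fixed: split by the actual projection widths.
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    assert (q.shape[-1], k.shape[-1], v.shape[-1]) == (4096, 1024, 1024)

For standard multi-head attention the two calls coincide, which is why the bug only surfaced on GQA and MQA checkpoints.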