Add more dims for batch invariant shims (#27489)

Signed-off-by: Bram Wasti <bwasti@meta.com>
Signed-off-by: Bram Wasti <bwasti@fb.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Bram Wasti
Date: 2025-10-30 01:28:45 -04:00
Committed by: GitHub
Parent: 8bff831f0a
Commit: ded8ada86a

@@ -478,9 +478,48 @@ def matmul_batch_invariant(a, b, *, out=None):
     elif a.ndim == 3 and b.ndim == 3:
         # Handle batched case like bmm
         return bmm_batch_invariant(a, b, out=out)
+    elif a.ndim == 3 and b.ndim == 2:
+        # Handle 3D x 2D: common for linear layers
+        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
+        # Reshape to 2D, do mm, reshape back
+        batch, seq, hidden = a.shape
+        a_2d = a.reshape(-1, hidden)
+        result_2d = matmul_persistent(a_2d, b)
+        result = result_2d.reshape(batch, seq, -1)
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    elif a.ndim == 2 and b.ndim == 3:
+        # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
+        # By broadcasting `a` to 3D, we can reuse the batched matrix
+        # multiplication logic.
+        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
+        return bmm_batch_invariant(a_expanded, b, out=out)
+    elif a.ndim == 4 and b.ndim == 4:
+        # Handle 4D attention tensors: [batch, heads, seq, dim]
+        # Reshape to 3D, process, reshape back
+        batch, heads, seq_a, dim_a = a.shape
+        _, _, dim_b, seq_b = b.shape
+        # Reshape to [batch*heads, seq_a, dim_a]
+        a_3d = a.reshape(batch * heads, seq_a, dim_a)
+        b_3d = b.reshape(batch * heads, dim_b, seq_b)
+        # Do batched matmul
+        result_3d = bmm_batch_invariant(a_3d, b_3d)
+        # Reshape back to [batch, heads, seq_a, seq_b]
+        result = result_3d.reshape(batch, heads, seq_a, seq_b)
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
     else:
         raise ValueError(
-            f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
+            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
+            f"3D x 2D, 2D x 3D, and 4D x 4D, "
             f"got shapes {a.shape} and {b.shape}"
         )
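
A minimal shape check exercising the new dispatch paths, for reference only.
This sketch is not part of the commit: the import path is an assumption, and a
CUDA device is assumed since the shims are backed by Triton kernels such as
matmul_persistent.

    import torch

    # Assumed import path for the module this diff touches.
    from vllm.model_executor.layers.batch_invariant import matmul_batch_invariant

    a3 = torch.randn(4, 16, 32, device="cuda")    # (batch, seq, hidden)
    w = torch.randn(32, 8, device="cuda")         # (hidden, out)
    assert matmul_batch_invariant(a3, w).shape == (4, 16, 8)     # 3D x 2D

    a2 = torch.randn(16, 32, device="cuda")       # (M, K)
    b3 = torch.randn(4, 32, 8, device="cuda")     # (B, K, N)
    assert matmul_batch_invariant(a2, b3).shape == (4, 16, 8)    # 2D x 3D

    q = torch.randn(2, 8, 16, 64, device="cuda")  # [batch, heads, seq_a, dim]
    k = torch.randn(2, 8, 64, 16, device="cuda")  # [batch, heads, dim, seq_b]
    assert matmul_batch_invariant(q, k).shape == (2, 8, 16, 16)  # 4D x 4D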
@@ -667,7 +706,8 @@ def rms_norm_batch_invariant(
 def linear_batch_invariant(input, weight, bias=None):
-    output = mm_batch_invariant(input, weight.t())
+    output = matmul_batch_invariant(input, weight.t())
     if bias is not None:
         output = output + bias
     return output
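
With the switch from mm_batch_invariant to matmul_batch_invariant,
linear_batch_invariant can now accept 3D activations directly via the new
3D x 2D path. A usage sketch under the same assumptions as above:

    import torch

    # Assumed import path, as above.
    from vllm.model_executor.layers.batch_invariant import linear_batch_invariant

    x = torch.randn(4, 16, 32, device="cuda")  # (batch, seq, hidden)
    w = torch.randn(8, 32, device="cuda")      # (out, hidden); transposed internally
    b = torch.randn(8, device="cuda")
    y = linear_batch_invariant(x, w, b)        # routes through the new 3D x 2D path
    assert y.shape == (4, 16, 8)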