Add more dims for batch invariant shims (#27489)
Signed-off-by: Bram Wasti <bwasti@meta.com> Signed-off-by: Bram Wasti <bwasti@fb.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@ -478,9 +478,48 @@ def matmul_batch_invariant(a, b, *, out=None):
|
||||
elif a.ndim == 3 and b.ndim == 3:
|
||||
# Handle batched case like bmm
|
||||
return bmm_batch_invariant(a, b, out=out)
|
||||
elif a.ndim == 3 and b.ndim == 2:
|
||||
# Handle 3D x 2D: common for linear layers
|
||||
# (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
|
||||
# Reshape to 2D, do mm, reshape back
|
||||
batch, seq, hidden = a.shape
|
||||
a_2d = a.reshape(-1, hidden)
|
||||
result_2d = matmul_persistent(a_2d, b)
|
||||
result = result_2d.reshape(batch, seq, -1)
|
||||
if out is not None:
|
||||
out.copy_(result)
|
||||
return out
|
||||
return result
|
||||
elif a.ndim == 2 and b.ndim == 3:
|
||||
# Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
|
||||
# By broadcasting `a` to 3D, we can reuse the batched matrix
|
||||
# multiplication logic.
|
||||
a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
|
||||
return bmm_batch_invariant(a_expanded, b, out=out)
|
||||
elif a.ndim == 4 and b.ndim == 4:
|
||||
# Handle 4D attention tensors: [batch, heads, seq, dim]
|
||||
# Reshape to 3D, process, reshape back
|
||||
batch, heads, seq_a, dim_a = a.shape
|
||||
_, _, dim_b, seq_b = b.shape
|
||||
|
||||
# Reshape to [batch*heads, seq_a, dim_a]
|
||||
a_3d = a.reshape(batch * heads, seq_a, dim_a)
|
||||
b_3d = b.reshape(batch * heads, dim_b, seq_b)
|
||||
|
||||
# Do batched matmul
|
||||
result_3d = bmm_batch_invariant(a_3d, b_3d)
|
||||
|
||||
# Reshape back to [batch, heads, seq_a, seq_b]
|
||||
result = result_3d.reshape(batch, heads, seq_a, seq_b)
|
||||
|
||||
if out is not None:
|
||||
out.copy_(result)
|
||||
return out
|
||||
return result
|
||||
else:
|
||||
raise ValueError(
|
||||
f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
|
||||
f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
|
||||
f"3D x 2D, 2D x 3D, and 4D x 4D, "
|
||||
f"got shapes {a.shape} and {b.shape}"
|
||||
)
|
||||
|
||||
@ -667,7 +706,8 @@ def rms_norm_batch_invariant(
|
||||
|
||||
|
||||
def linear_batch_invariant(input, weight, bias=None):
    """Apply a linear transform using the batch-invariant matmul shim.

    Computes ``input @ weight.t()`` via ``matmul_batch_invariant`` and, when
    a bias is supplied, adds it to the product.

    Args:
        input: Activation tensor. Presumably ``(..., in_features)``; the
            ranks actually accepted are whatever ``matmul_batch_invariant``
            supports -- TODO confirm against that function.
        weight: Weight tensor, transposed with ``.t()`` before the product.
            Presumably ``(out_features, in_features)`` -- TODO confirm.
        bias: Optional tensor added to the product; skipped when ``None``.

    Returns:
        The product of ``input`` and the transposed ``weight``, plus
        ``bias`` when one is given.
    """
    # matmul_batch_invariant dispatches on operand rank (e.g. 3D activations
    # times a 2D weight), so the input does not have to be flattened to 2D
    # by the caller. The earlier mm_batch_invariant call is dropped: its
    # result was immediately overwritten by this one.
    output = matmul_batch_invariant(input, weight.t())
    if bias is not None:
        output = output + bias
    return output
|
||||
|
||||
Reference in New Issue
Block a user