Switch to optimized_attention_for_device in camera DA3 module.

This commit is contained in:
Talmaj Marinc
2026-06-01 16:09:39 +02:00
parent f57ebb4cf4
commit 15f4dc401a

View File

@ -18,6 +18,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention_for_device
from .transform import affine_inverse, extri_intri_to_pose_encoding
@ -74,11 +75,10 @@ class _Attention(nn.Module):
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, h, N, d
q, k, v = qkv.unbind(0)
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(B, N, C)
qkv = self.qkv(x).reshape(B, N, 3, C)
q, k, v = qkv.unbind(2) # each (B, N, C)
attn_fn = optimized_attention_for_device(x.device, small_input=True)
out = attn_fn(q, k, v, heads=self.num_heads)
return self.proj(out)