mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-08 08:17:24 +08:00
Switch to optimized_attention_for_device in camera DA3 module.
This commit is contained in:
@ -18,6 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from .transform import affine_inverse, extri_intri_to_pose_encoding
|
||||
|
||||
|
||||
@ -74,11 +75,10 @@ class _Attention(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
||||
qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, h, N, d
|
||||
q, k, v = qkv.unbind(0)
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
out = out.transpose(1, 2).reshape(B, N, C)
|
||||
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||
q, k, v = qkv.unbind(2) # each (B, N, C)
|
||||
attn_fn = optimized_attention_for_device(x.device, small_input=True)
|
||||
out = attn_fn(q, k, v, heads=self.num_heads)
|
||||
return self.proj(out)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user