mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 19:35:57 +08:00
Compare commits
31 Commits
feat/api-n
...
depth-anyt
| Author | SHA1 | Date | |
|---|---|---|---|
| aefc61f42d | |||
| 15c096aa16 | |||
| 2dd9d96d4a | |||
| 2ed1f36471 | |||
| 7cb2394630 | |||
| ddb8739963 | |||
| e3a74d1696 | |||
| 359da6d0b4 | |||
| b65bdc3737 | |||
| d66f385502 | |||
| 22982da481 | |||
| 9e30a0b56c | |||
| 81c8afb36d | |||
| 102704f9fb | |||
| ccdc6517fd | |||
| 33c0421153 | |||
| 50729162f4 | |||
| d6aad2f8c7 | |||
| 67278c5851 | |||
| 864249d3c1 | |||
| d7d149b754 | |||
| 47d3d90380 | |||
| 33235ab099 | |||
| a9a993130e | |||
| c3c9ce39f9 | |||
| fb19982410 | |||
| 19c25c8cbd | |||
| 33701de3a8 | |||
| 82e4db5d4a | |||
| 1ffa010952 | |||
| 9b19fab3da |
@ -1,7 +1,13 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.ldm.depth_anything_3.reference_view_selector import (
|
||||
select_reference_view, reorder_by_reference, restore_original_order,
|
||||
THRESH_FOR_REF_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
@ -14,13 +20,42 @@ class Dino2AttentionOutput(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.head_dim = embed_dim // heads
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
if qk_norm:
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
def forward(self, x, mask, optimized_attention, pos=None, rope=None):
|
||||
# Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
|
||||
if self.q_norm is None and rope is None:
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
# DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
|
||||
attn = self.attention
|
||||
B, N, C = x.shape
|
||||
h = self.heads
|
||||
d = self.head_dim
|
||||
q = attn.query(x).view(B, N, h, d).transpose(1, 2)
|
||||
k = attn.key(x).view(B, N, h, d).transpose(1, 2)
|
||||
v = attn.value(x).view(B, N, h, d).transpose(1, 2)
|
||||
if self.q_norm is not None:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if rope is not None and pos is not None:
|
||||
q = rope(q, pos)
|
||||
k = rope(k, pos)
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
out = out.transpose(1, 2).reshape(B, N, C)
|
||||
return self.output(out)
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
@ -64,9 +99,11 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=qk_norm)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
@ -76,19 +113,90 @@ class Dino2Block(torch.nn.Module):
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention,
|
||||
pos=pos, rope=rope))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2D Rotary position embedding (DA3 extension)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PositionGetter:
|
||||
"""Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict = {}
|
||||
|
||||
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
|
||||
key = (height, width, device)
|
||||
if key not in self._cache:
|
||||
y = torch.arange(height, device=device)
|
||||
x = torch.arange(width, device=device)
|
||||
self._cache[key] = torch.cartesian_prod(y, x)
|
||||
cached = self._cache[key]
|
||||
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(torch.nn.Module):
|
||||
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
|
||||
|
||||
def __init__(self, frequency: float = 100.0):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
self.base_frequency = frequency
|
||||
self._freq_cache: dict = {}
|
||||
|
||||
def _components(self, dim: int, seq_len: int, device, dtype):
|
||||
key = (dim, seq_len, device, dtype)
|
||||
if key not in self._freq_cache:
|
||||
exp = torch.arange(0, dim, 2, device=device).float() / dim
|
||||
inv_freq = 1.0 / (self.base_frequency ** exp)
|
||||
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
ang = torch.einsum("i,j->ij", pos, inv_freq)
|
||||
ang = ang.to(dtype)
|
||||
ang = torch.cat((ang, ang), dim=-1)
|
||||
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
|
||||
return self._freq_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _rotate(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1]
|
||||
x1, x2 = x[..., : d // 2], x[..., d // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_1d(self, tokens, positions, cos_c, sin_c):
|
||||
cos = F.embedding(positions, cos_c)[:, None, :, :]
|
||||
sin = F.embedding(positions, sin_c)[:, None, :, :]
|
||||
return (tokens * cos) + (self._rotate(tokens) * sin)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
feature_dim = tokens.size(-1) // 2
|
||||
max_pos = int(positions.max()) + 1
|
||||
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
|
||||
v, h = tokens.chunk(2, dim=-1)
|
||||
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
|
||||
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
|
||||
return torch.cat((v, h), dim=-1)
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
|
||||
qknorm_start: int = -1):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([
|
||||
Dino2Block(
|
||||
dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
@ -122,16 +230,27 @@ class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
def __init__(self, dim, dtype, device, operations,
|
||||
patch_size: int = 14, image_size: int = 518,
|
||||
use_mask_token: bool = True,
|
||||
num_camera_tokens: int = 0):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
if use_mask_token:
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.mask_token = None
|
||||
if num_camera_tokens > 0:
|
||||
# DA3 stores (ref_token, src_token) pairs that get injected at the
|
||||
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
|
||||
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.camera_token = None
|
||||
|
||||
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
|
||||
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
|
||||
@ -140,12 +259,22 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
patch_pos = pos_embed[:, 1:]
|
||||
N = patch_pos.shape[1]
|
||||
M = int(N ** 0.5)
|
||||
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
|
||||
h0 = h_pixels // self.patch_size
|
||||
w0 = w_pixels // self.patch_size
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# scale_factor is (height_scale, width_scale) -- height MUST come first;
|
||||
# swapping these only happens to work for square inputs and breaks
|
||||
# non-square paths like DA3-Small / DA3-Base multi-view.
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
|
||||
|
||||
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
|
||||
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
|
||||
assert (h0, w0) == patch_pos.shape[-2:], (
|
||||
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
|
||||
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
|
||||
f"check scale_factor axis order and +0.1 rounding workaround"
|
||||
)
|
||||
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
|
||||
|
||||
@ -161,6 +290,21 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
|
||||
|
||||
class Dinov2Model(torch.nn.Module):
|
||||
"""DINOv2 vision backbone.
|
||||
|
||||
Supports two operating modes:
|
||||
|
||||
* **CLIP-vision DINOv2** (default): vanilla DINOv2-ViT used for
|
||||
``ClipVisionModel`` and SigLIP-style image encoding.
|
||||
* **Depth Anything 3** extensions (opt-in via config keys): 2D RoPE,
|
||||
QK-norm, alternating local/global attention, camera-token injection,
|
||||
``cat_token`` output and multi-layer feature extraction. These are
|
||||
enabled when the corresponding fields (``alt_start``, ``qknorm_start``,
|
||||
``rope_start``, ``cat_token``) are set in ``config_dict``. When all of
|
||||
them are at their disabled defaults this module behaves identically to
|
||||
the historical ``Dinov2Model``.
|
||||
"""
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
@ -168,12 +312,51 @@ class Dinov2Model(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
patch_size = config_dict.get("patch_size", 14)
|
||||
image_size = config_dict.get("image_size", 518)
|
||||
use_mask_token = config_dict.get("use_mask_token", True)
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
# DA3 extensions (all default to disabled).
|
||||
self.alt_start = config_dict.get("alt_start", -1)
|
||||
self.qknorm_start = config_dict.get("qknorm_start", -1)
|
||||
self.rope_start = config_dict.get("rope_start", -1)
|
||||
self.cat_token = config_dict.get("cat_token", False)
|
||||
rope_freq = config_dict.get("rope_freq", 100.0)
|
||||
|
||||
self.embed_dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.num_register_tokens = 0
|
||||
self.patch_start_idx = 1
|
||||
|
||||
if self.rope_start != -1 and rope_freq > 0:
|
||||
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
|
||||
self._position_getter = _PositionGetter()
|
||||
else:
|
||||
self.rope = None
|
||||
self._position_getter = None
|
||||
|
||||
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
|
||||
num_cam_tokens = 2 if self.alt_start != -1 else 0
|
||||
|
||||
self.embeddings = Dino2Embeddings(
|
||||
dim, dtype, device, operations,
|
||||
patch_size=patch_size, image_size=image_size,
|
||||
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
|
||||
)
|
||||
self.encoder = Dino2Encoder(
|
||||
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qknorm_start=self.qknorm_start,
|
||||
)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
if self.alt_start != -1:
|
||||
raise RuntimeError(
|
||||
"Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not "
|
||||
"apply DA3 extensions (RoPE, alternating attention, camera-token injection). "
|
||||
"Use get_intermediate_layers_da3() for Depth Anything 3 models."
|
||||
)
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
@ -181,6 +364,21 @@ class Dinov2Model(torch.nn.Module):
|
||||
return x, i, pooled_output, None
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
|
||||
"""Single-view multi-layer feature extraction (MoGe / vanilla DINOv2).
|
||||
|
||||
For the multi-view Depth Anything 3 path (RoPE, alt-attention,
|
||||
camera-token injection, ref-view selection, cat_token), use
|
||||
:meth:`get_intermediate_layers_da3` instead.
|
||||
|
||||
Args:
|
||||
pixel_values: ``(B, 3, H, W)`` single-view input.
|
||||
indices: layer indices to extract; supports negative indexing.
|
||||
apply_norm: if True, apply the final layernorm to each output.
|
||||
|
||||
Returns:
|
||||
list of ``(patch_tokens, cls_token)`` tuples with shapes
|
||||
``(B, N_patch, C)`` and ``(B, C)`` (one entry per ``indices``).
|
||||
"""
|
||||
x = self.embeddings(pixel_values)
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
n_layers = len(self.encoder.layer)
|
||||
@ -197,3 +395,166 @@ class Dinov2Model(torch.nn.Module):
|
||||
if i >= max_idx:
|
||||
break
|
||||
return [cache[i] for i in resolved]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Depth Anything 3 forward
|
||||
# ------------------------------------------------------------------
|
||||
def _prepare_rope_positions(self, B, S, H, W, device):
|
||||
if self.rope is None:
|
||||
return None, None
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
pos = self._position_getter(B * S, ph, pw, device=device)
|
||||
# Shift so the cls/cam token at position 0 is reserved for "no diff".
|
||||
pos = pos + 1
|
||||
cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
|
||||
# Per-view local: real grid positions for patches, 0 for cls token.
|
||||
pos_local = torch.cat([cls_pos, pos], dim=1)
|
||||
# Global (across views): same grid positions; cls token still at 0,
|
||||
# but patches share the same positions in every view.
|
||||
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
|
||||
return pos_local, pos_global
|
||||
|
||||
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int,
|
||||
cam_token: "torch.Tensor | None") -> torch.Tensor:
|
||||
# x: (B, S, N, C). Replace token at index 0 with the camera token.
|
||||
if cam_token is not None:
|
||||
inj = cam_token
|
||||
else:
|
||||
ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
|
||||
ref_token = ct[:, :1].expand(B, -1, -1)
|
||||
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
|
||||
inj = torch.cat([ref_token, src_token], dim=1)
|
||||
x = x.clone()
|
||||
x[:, :, 0] = inj
|
||||
return x
|
||||
|
||||
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None,
|
||||
ref_view_strategy="saddle_balanced",
|
||||
export_feat_layers=None):
|
||||
"""Multi-view multi-layer feature extraction used by Depth Anything 3.
|
||||
|
||||
Adds RoPE positions, alternating local/global attention across views,
|
||||
camera-token injection, reference-view selection/reordering,
|
||||
``cat_token`` output and optional auxiliary feature exports on top of
|
||||
the vanilla DINOv2 path. For the single-view MoGe / CLIP-vision use
|
||||
case, see :meth:`get_intermediate_layers`.
|
||||
|
||||
Args:
|
||||
pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.
|
||||
out_layers: indices into ``self.encoder.layer``.
|
||||
cam_token: optional ``(B, S, dim)`` camera token to inject at
|
||||
``alt_start``. If ``None`` and the model has its own
|
||||
``camera_token`` parameter, that is used.
|
||||
ref_view_strategy: when ``S >= 3`` and ``cam_token is None``,
|
||||
pick a reference view via this strategy and move it to
|
||||
position 0 right before the first alt-attention block.
|
||||
The original view order is restored on the way out.
|
||||
export_feat_layers: optional iterable of layer indices whose
|
||||
local attention outputs to also return as auxiliary
|
||||
features (``(B, S, N_patch, C)`` after final norm). Used
|
||||
by the multi-view path to expose intermediate features
|
||||
to the nested-architecture wrapper.
|
||||
|
||||
Returns:
|
||||
``(layer_outputs, aux_outputs)`` where ``layer_outputs`` is a
|
||||
list of ``(patch_tokens, cls_or_cam_token)`` tuples (one per
|
||||
``out_layers`` entry) and ``aux_outputs`` is a list of
|
||||
``(B, S, N_patch, C)`` features for ``export_feat_layers``
|
||||
(empty list when not requested).
|
||||
"""
|
||||
if pixel_values.ndim == 4:
|
||||
pixel_values = pixel_values.unsqueeze(1)
|
||||
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
|
||||
f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
|
||||
B, S, _, H, W = pixel_values.shape
|
||||
|
||||
# Patch + cls + (interpolated) pos embed for each view.
|
||||
x = pixel_values.reshape(B * S, 3, H, W)
|
||||
x = self.embeddings(x) # (B*S, 1+N, C)
|
||||
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
|
||||
|
||||
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
|
||||
# ``optimized_attention`` is only used by blocks without QK-norm/RoPE
|
||||
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
out_set = set(out_layers)
|
||||
export_set = set(export_feat_layers) if export_feat_layers else set()
|
||||
outputs: list[torch.Tensor] = []
|
||||
aux_outputs: list[torch.Tensor] = []
|
||||
local_x = x
|
||||
b_idx = None
|
||||
|
||||
|
||||
for i, blk in enumerate(self.encoder.layer):
|
||||
apply_rope = self.rope is not None and i >= self.rope_start
|
||||
block_rope = self.rope if apply_rope else None
|
||||
l_pos = pos_local if apply_rope else None
|
||||
g_pos = pos_global if apply_rope else None
|
||||
|
||||
# Reference-view selection threshold: matches the upstream constant
|
||||
# ``THRESH_FOR_REF_SELECTION = 3``. Skipped when a user-supplied
|
||||
# cam_token is provided (camera info already pins the geometry).
|
||||
if (self.alt_start != -1 and i == self.alt_start - 1
|
||||
and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
|
||||
b_idx = select_reference_view(x, strategy=ref_view_strategy)
|
||||
x = reorder_by_reference(x, b_idx)
|
||||
local_x = reorder_by_reference(local_x, b_idx)
|
||||
|
||||
if self.alt_start != -1 and i == self.alt_start:
|
||||
x = self._inject_camera_token(x, B, S, cam_token)
|
||||
|
||||
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
|
||||
# Global attention across views: flatten S into the seq dim.
|
||||
t = x.reshape(B, S * x.shape[-2], x.shape[-1])
|
||||
p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
else:
|
||||
# Per-view local attention.
|
||||
t = x.reshape(B * S, x.shape[-2], x.shape[-1])
|
||||
p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
local_x = x
|
||||
|
||||
if i in out_set:
|
||||
if self.cat_token:
|
||||
out_x = torch.cat([local_x, x], dim=-1)
|
||||
else:
|
||||
out_x = x
|
||||
# Restore original view order on the way out so heads see views
|
||||
# in the user's expected order.
|
||||
if b_idx is not None and self.alt_start != -1:
|
||||
out_x = restore_original_order(out_x, b_idx)
|
||||
outputs.append(out_x)
|
||||
|
||||
if i in export_set:
|
||||
aux = x
|
||||
if b_idx is not None and self.alt_start != -1:
|
||||
aux = restore_original_order(aux, b_idx)
|
||||
aux_outputs.append(aux)
|
||||
|
||||
# Apply final norm. When ``cat_token`` is set, only the right half
|
||||
# ("global" features) is normalised; the left half is left as-is to
|
||||
# match the upstream DA3 head signature.
|
||||
normed: list[torch.Tensor] = []
|
||||
cls_tokens: list[torch.Tensor] = []
|
||||
for out_x in outputs:
|
||||
cls_tokens.append(out_x[:, :, 0])
|
||||
if out_x.shape[-1] == self.embed_dim:
|
||||
normed.append(self.layernorm(out_x))
|
||||
elif out_x.shape[-1] == self.embed_dim * 2:
|
||||
left = out_x[..., :self.embed_dim]
|
||||
right = self.layernorm(out_x[..., self.embed_dim:])
|
||||
normed.append(torch.cat([left, right], dim=-1))
|
||||
else:
|
||||
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
|
||||
|
||||
# Drop cls/cam token from the patch sequence.
|
||||
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
|
||||
|
||||
# Final layernorm + drop cls token from auxiliary features too.
|
||||
aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :]
|
||||
for o in aux_outputs]
|
||||
return list(zip(normed, cls_tokens)), aux_normed
|
||||
|
||||
@ -799,15 +799,13 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HiDreamO1Pixel(ChromaRadiance):
|
||||
"""Pixel-space latent format for HiDream-O1.
|
||||
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
|
||||
"""
|
||||
pass
|
||||
|
||||
class PixelDiTPixel(ChromaRadiance):
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||
|
||||
|
||||
25
comfy/ldm/colormap.py
Normal file
25
comfy/ldm/colormap.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""Colormap utilities for depth and geometry visualisation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the Turbo colormap.
|
||||
|
||||
Args:
|
||||
x: Float tensor with values in [0, 1].
|
||||
|
||||
Returns:
|
||||
RGB tensor of the same shape as ``x`` with a trailing size-3 dimension.
|
||||
"""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
7
comfy/ldm/depth_anything_3/__init__.py
Normal file
7
comfy/ldm/depth_anything_3/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Depth Anything 3 - native ComfyUI port (Apache-2.0 monocular variants only).
|
||||
#
|
||||
# Supported variants:
|
||||
# DA3-Small, DA3-Base (vits/vitb backbone, DualDPT head)
|
||||
# DA3Mono-Large, DA3Metric-Large (vitl backbone, DPT head + sky mask)
|
||||
#
|
||||
# Original repo: https://github.com/ByteDance-Seed/Depth-Anything-3
|
||||
204
comfy/ldm/depth_anything_3/camera.py
Normal file
204
comfy/ldm/depth_anything_3/camera.py
Normal file
@ -0,0 +1,204 @@
|
||||
"""Camera-token encoder and decoder for Depth Anything 3.
|
||||
|
||||
* :class:`CameraEnc` takes per-view extrinsics + intrinsics and produces a
|
||||
per-view camera token that gets injected at the alt-attention boundary
|
||||
in the DINOv2 backbone (block ``alt_start``).
|
||||
* :class:`CameraDec` takes the final-layer camera token output by the
|
||||
backbone and predicts a 9-D pose encoding (translation, quaternion,
|
||||
field-of-view).
|
||||
|
||||
The module/parameter names match the upstream ``cam_enc.py``/``cam_dec.py``
|
||||
so HF safetensors load directly with no key remapping (the upstream uses
|
||||
fused QKV linears, which we replicate here).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .transform import affine_inverse, extri_intri_to_pose_encoding
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Building blocks (mirror ``depth_anything_3.model.utils.{attention,block}``)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Mlp(nn.Module):
|
||||
"""Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``."""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None,
|
||||
*, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(F.gelu(self.fc1(x)))
|
||||
|
||||
|
||||
class _LayerScale(nn.Module):
|
||||
"""Per-channel learnable scaling. Matches upstream ``LayerScale``."""
|
||||
|
||||
def __init__(self, dim, *, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gamma.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
class _Attention(nn.Module):
|
||||
"""Self-attention with fused QKV projection.
|
||||
|
||||
Mirrors upstream ``utils.attention.Attention``; layout matches the
|
||||
HF safetensors (``attn.qkv.{weight,bias}`` and ``attn.proj.{weight,bias}``).
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_heads,
|
||||
*, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
||||
qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, h, N, d
|
||||
q, k, v = qkv.unbind(0)
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
out = out.transpose(1, 2).reshape(B, N, C)
|
||||
return self.proj(out)
|
||||
|
||||
|
||||
class _Block(nn.Module):
|
||||
"""Pre-norm transformer block with LayerScale.
|
||||
|
||||
Used by :class:`CameraEnc`. Layout follows upstream ``utils.block.Block``.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01,
|
||||
*, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.attn = _Attention(dim, num_heads,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.ls1(self.attn(self.norm1(x)))
|
||||
x = x + self.ls2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class CameraEnc(nn.Module):
|
||||
"""Encode per-view (extrinsics, intrinsics) into a camera token.
|
||||
|
||||
Maps a 9-D pose-encoding vector through a small MLP up to the backbone's
|
||||
``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output
|
||||
has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start``
|
||||
of the DINOv2 backbone in place of the cls token.
|
||||
|
||||
Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_out: int = 1024,
|
||||
dim_in: int = 9,
|
||||
trunk_depth: int = 4,
|
||||
target_dim: int = 9,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: int = 4,
|
||||
init_values: float = 0.01,
|
||||
*,
|
||||
device=None, dtype=None, operations=None,
|
||||
**_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.target_dim = target_dim
|
||||
self.trunk_depth = trunk_depth
|
||||
self.trunk = nn.Sequential(*[
|
||||
_Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(trunk_depth)
|
||||
])
|
||||
self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
|
||||
self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
|
||||
self.pose_branch = _Mlp(
|
||||
in_features=dim_in,
|
||||
hidden_features=dim_out // 2,
|
||||
out_features=dim_out,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor,
|
||||
image_size_hw) -> torch.Tensor:
|
||||
"""Encode camera parameters into ``(B, S, dim_out)`` tokens."""
|
||||
c2ws = affine_inverse(extrinsics)
|
||||
pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw)
|
||||
tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype))
|
||||
tokens = self.token_norm(tokens)
|
||||
tokens = self.trunk(tokens)
|
||||
tokens = self.trunk_norm(tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
class CameraDec(nn.Module):
|
||||
"""Decode the final cam token into a 9-D pose encoding.
|
||||
|
||||
Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is
|
||||
always predicted by the network; the quaternion and FoV can either be
|
||||
predicted or supplied via ``camera_encoding`` (used at training time
|
||||
when GT cameras are available -- not exercised at inference here).
|
||||
|
||||
Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int = 1536,
|
||||
*, device=None, dtype=None, operations=None, **_kwargs):
|
||||
super().__init__()
|
||||
d = dim_in
|
||||
self.backbone = nn.Sequential(
|
||||
operations.Linear(d, d, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
operations.Linear(d, d, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype)
|
||||
self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype)
|
||||
self.fc_fov = nn.Sequential(
|
||||
operations.Linear(d, 2, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, feat: torch.Tensor,
|
||||
camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor:
|
||||
"""Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc."""
|
||||
B, N = feat.shape[:2]
|
||||
feat = feat.reshape(B * N, -1)
|
||||
feat = self.backbone(feat)
|
||||
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
|
||||
if camera_encoding is None:
|
||||
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
|
||||
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
|
||||
else:
|
||||
out_qvec = camera_encoding[..., 3:7]
|
||||
out_fov = camera_encoding[..., -2:]
|
||||
return torch.cat([out_t, out_qvec, out_fov], dim=-1)
|
||||
549
comfy/ldm/depth_anything_3/dpt.py
Normal file
549
comfy/ldm/depth_anything_3/dpt.py
Normal file
@ -0,0 +1,549 @@
|
||||
# DPT / DualDPT heads for Depth Anything 3.
|
||||
#
|
||||
# Ported from:
|
||||
# src/depth_anything_3/model/dpt.py (DPT - single main head + sky head)
|
||||
# src/depth_anything_3/model/dualdpt.py (DualDPT - depth + auxiliary "ray" head)
|
||||
#
|
||||
# In the monocular path we always discard the auxiliary "ray" output of
|
||||
# DualDPT. The auxiliary branch is still constructed so that DA3 HF weights
|
||||
# load cleanly without missing-key warnings.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helpers (matching upstream head_utils.py)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Permute(nn.Module):
|
||||
def __init__(self, dims: Tuple[int, ...]):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(*self.dims)
|
||||
|
||||
|
||||
def _custom_interpolate(
|
||||
x: torch.Tensor,
|
||||
size: Optional[Tuple[int, int]] = None,
|
||||
scale_factor: Optional[float] = None,
|
||||
mode: str = "bilinear",
|
||||
align_corners: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if size is None:
|
||||
assert scale_factor is not None
|
||||
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
||||
INT_MAX = 1610612736
|
||||
total = size[0] * size[1] * x.shape[0] * x.shape[1]
|
||||
if total > INT_MAX:
|
||||
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
|
||||
outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks]
|
||||
return torch.cat(outs, dim=0).contiguous()
|
||||
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
||||
|
||||
|
||||
def _create_uv_grid(width: int, height: int, aspect_ratio: float,
|
||||
dtype, device) -> torch.Tensor:
|
||||
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
|
||||
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
|
||||
span_x = aspect_ratio / diag_factor
|
||||
span_y = 1.0 / diag_factor
|
||||
left_x = -span_x * (width - 1) / width
|
||||
right_x = span_x * (width - 1) / width
|
||||
top_y = -span_y * (height - 1) / height
|
||||
bottom_y = span_y * (height - 1) / height
|
||||
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
||||
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
||||
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
||||
return torch.stack((uu, vv), dim=-1) # (H, W, 2)
|
||||
|
||||
|
||||
def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor:
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
|
||||
omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0))
|
||||
pos = pos.reshape(-1)
|
||||
out = torch.einsum("m,d->md", pos, omega)
|
||||
return torch.cat([out.sin(), out.cos()], dim=1).float()
|
||||
|
||||
|
||||
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int,
|
||||
omega_0: float = 100.0) -> torch.Tensor:
|
||||
H, W, _ = pos_grid.shape
|
||||
pos_flat = pos_grid.reshape(-1, 2)
|
||||
emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)
|
||||
emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)
|
||||
emb = torch.cat([emb_x, emb_y], dim=-1)
|
||||
return emb.view(H, W, embed_dim)
|
||||
|
||||
|
||||
def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
||||
"""Stateless UV positional embedding added to a feature map (B, C, h, w)."""
|
||||
pw, ph = x.shape[-1], x.shape[-2]
|
||||
pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
||||
pe = _position_grid_to_embed(pe, x.shape[1]) * ratio
|
||||
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype)
|
||||
return x + pe
|
||||
|
||||
|
||||
def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
|
||||
act = (activation or "linear").lower()
|
||||
if act == "exp":
|
||||
return torch.exp(x)
|
||||
if act == "expp1":
|
||||
return torch.exp(x) + 1
|
||||
if act == "expm1":
|
||||
return torch.expm1(x)
|
||||
if act == "relu":
|
||||
return torch.relu(x)
|
||||
if act == "sigmoid":
|
||||
return torch.sigmoid(x)
|
||||
if act == "softplus":
|
||||
return F.softplus(x)
|
||||
if act == "tanh":
|
||||
return torch.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Fusion building blocks
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
def __init__(self, features: int,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
def __init__(self, features: int, has_residual: bool = True,
|
||||
align_corners: bool = True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.align_corners = align_corners
|
||||
self.has_residual = has_residual
|
||||
if has_residual:
|
||||
self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.resConfUnit1 = None
|
||||
self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
|
||||
y = xs[0]
|
||||
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
|
||||
y = y + self.resConfUnit1(xs[1])
|
||||
y = self.resConfUnit2(y)
|
||||
if size is None:
|
||||
up_kwargs = {"scale_factor": 2.0}
|
||||
else:
|
||||
up_kwargs = {"size": size}
|
||||
y = _custom_interpolate(y, **up_kwargs, mode="bilinear",
|
||||
align_corners=self.align_corners)
|
||||
y = self.out_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class _Scratch(nn.Module):
|
||||
"""Container that mirrors upstream ``scratch`` attribute layout."""
|
||||
|
||||
|
||||
def _make_scratch(in_shape: List[int], out_shape: int,
|
||||
device=None, dtype=None, operations=None) -> _Scratch:
|
||||
scratch = _Scratch()
|
||||
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_fusion_block(features: int, has_residual: bool = True,
|
||||
device=None, dtype=None, operations=None) -> FeatureFusionBlock:
|
||||
return FeatureFusionBlock(features, has_residual=has_residual,
|
||||
align_corners=True,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DPT (single head + optional sky head) -- used by DA3Mono/Metric
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DPT(nn.Module):
|
||||
"""Single-head DPT used by DA3Mono-Large and DA3Metric-Large."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 1,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = False,
|
||||
down_ratio: int = 1,
|
||||
head_name: str = "depth",
|
||||
use_sky_head: bool = True,
|
||||
sky_name: str = "sky",
|
||||
sky_activation: str = "relu",
|
||||
norm_type: str = "idt",
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.head_main = head_name
|
||||
self.sky_name = sky_name
|
||||
self.out_dim = output_dim
|
||||
self.has_conf = output_dim > 1
|
||||
self.use_sky_head = use_sky_head
|
||||
self.sky_activation = sky_activation
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
|
||||
if norm_type == "layer":
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
if self.use_sky_head:
|
||||
self.scratch.sky_output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int,
|
||||
patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
# feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
|
||||
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
|
||||
out = self.scratch.refinenet1(out, l1_rn)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
fused = self.scratch.output_conv1(out)
|
||||
fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
fused = _add_pos_embed(fused, W, H)
|
||||
feat = fused
|
||||
|
||||
main_logits = self.scratch.output_conv2(feat)
|
||||
outs = {}
|
||||
if self.has_conf:
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1])
|
||||
outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:])
|
||||
else:
|
||||
pred = _apply_activation(main_logits, self.activation)
|
||||
outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:])
|
||||
|
||||
if self.use_sky_head:
|
||||
sky_logits = self.scratch.sky_output_conv2(feat)
|
||||
if self.sky_activation.lower() == "sigmoid":
|
||||
sky = torch.sigmoid(sky_logits)
|
||||
elif self.sky_activation.lower() == "relu":
|
||||
sky = F.relu(sky_logits)
|
||||
else:
|
||||
sky = sky_logits
|
||||
outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:])
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DualDPT(nn.Module):
|
||||
"""Two-head DPT used by DA3-Small / DA3-Base.
|
||||
|
||||
The auxiliary "ray" head is constructed so that HF state-dict keys load
|
||||
cleanly. It is only executed when :attr:`enable_aux` is set on the
|
||||
instance (typically by ``DepthAnything3Net`` when running multi-view
|
||||
with ``use_ray_pose=True``); otherwise the monocular path skips it for
|
||||
speed and the auxiliary submodules sit idle.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 2,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = True,
|
||||
down_ratio: int = 1,
|
||||
aux_pyramid_levels: int = 4,
|
||||
aux_out1_conv_num: int = 5,
|
||||
head_names: Tuple[str, str] = ("depth", "ray"),
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.aux_levels = aux_pyramid_levels
|
||||
self.aux_out1_conv_num = aux_out1_conv_num
|
||||
self.head_main, self.head_aux = head_names
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
# Toggle the auxiliary ray branch at runtime. Default off (mono path).
|
||||
# ``DepthAnything3Net`` flips this on when running multi-view + ray-pose.
|
||||
self.enable_aux: bool = False
|
||||
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
# Main fusion chain
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
# Auxiliary fusion chain (separate copies)
|
||||
self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
# Main head neck + final projection
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
# Aux pre-head per level (multi-level pyramid)
|
||||
self.scratch.output_conv1_aux = nn.ModuleList([
|
||||
self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
# Aux final projection per level (includes LayerNorm permute path).
|
||||
ln_seq = [Permute((0, 2, 3, 1)),
|
||||
operations.LayerNorm(head_features_2, device=device, dtype=dtype),
|
||||
Permute((0, 3, 1, 2))]
|
||||
self.scratch.output_conv2_aux = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
*ln_seq,
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
|
||||
# aux_out1_conv_num=5 in all Apache-2.0 variants.
|
||||
return nn.Sequential(
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int,
|
||||
patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
# Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
|
||||
# before interpolation -- replicate that order here).
|
||||
m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
|
||||
aux_pyr = [a4]
|
||||
m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:]))
|
||||
m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:]))
|
||||
m = self.scratch.refinenet1(m, l1_rn)
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn))
|
||||
m = self.scratch.output_conv1(m)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
m = _add_pos_embed(m, W, H)
|
||||
main_logits = self.scratch.output_conv2(m)
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
depth_pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
|
||||
outs = {
|
||||
self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
|
||||
f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
|
||||
}
|
||||
|
||||
if self.enable_aux:
|
||||
# Auxiliary "ray" head (multi-level inside) -- only the last level
|
||||
# is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``:
|
||||
# each aux pyramid level goes through ``output_conv1_aux[i]``
|
||||
# (5-layer conv stack that ends at ``features // 2`` channels),
|
||||
# then the last level optionally gets a pos-embed and finally
|
||||
# ``output_conv2_aux[-1]``.
|
||||
aux_processed = [
|
||||
self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr)
|
||||
]
|
||||
last_aux = aux_processed[-1]
|
||||
if self.pos_embed:
|
||||
last_aux = _add_pos_embed(last_aux, W, H)
|
||||
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
|
||||
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
|
||||
# Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
|
||||
aux_pred = fmap_last[..., :-1]
|
||||
aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
|
||||
outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
|
||||
outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])
|
||||
|
||||
return outs
|
||||
300
comfy/ldm/depth_anything_3/model.py
Normal file
300
comfy/ldm/depth_anything_3/model.py
Normal file
@ -0,0 +1,300 @@
|
||||
# DepthAnything3Net: top-level wrapper that combines backbone + head.
|
||||
#
|
||||
# Supports both the monocular and the multi-view + camera path:
|
||||
#
|
||||
# * Monocular: ``S = 1``, no camera encoder/decoder. Mirrors the original
|
||||
# port that only handled ``DA3-MONO/METRIC-LARGE`` and the auxiliary-disabled
|
||||
# ``DA3-SMALL/BASE`` configs.
|
||||
# * Multi-view + camera: ``S > 1``. ``cam_enc`` (optional) maps user-supplied
|
||||
# extrinsics + intrinsics into a per-view camera token; ``cam_dec`` decodes
|
||||
# the final layer's camera token into a 9-D pose encoding. When the
|
||||
# auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can
|
||||
# alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``).
|
||||
# The 3D-Gaussian head and the nested-architecture wrapper are intentionally
|
||||
# left out of scope here; their state-dict keys are filtered in
|
||||
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
|
||||
#
|
||||
# The backbone is shared with the CLIP-vision DINOv2 path
|
||||
# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions
|
||||
# (RoPE, QK-norm, alternating local/global attention, camera token, multi-
|
||||
# layer feature extraction, reference-view reordering) are opt-in via the
|
||||
# config dict and are all disabled for the Mono/Metric variants.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .camera import CameraDec, CameraEnc
|
||||
from .dpt import DPT, DualDPT
|
||||
from .ray_pose import get_extrinsic_from_camray
|
||||
from .transform import affine_inverse, pose_encoding_to_extri_intri
|
||||
|
||||
|
||||
_HEAD_REGISTRY = {
|
||||
"dpt": DPT,
|
||||
"dualdpt": DualDPT,
|
||||
}
|
||||
|
||||
|
||||
# Backbone presets (mirror the upstream DINOv2 ViT variants).
|
||||
_BACKBONE_PRESETS = {
|
||||
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
|
||||
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
|
||||
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
|
||||
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
|
||||
}
|
||||
|
||||
|
||||
def _build_backbone_config(
|
||||
backbone_name: str,
|
||||
*,
|
||||
alt_start: int,
|
||||
qknorm_start: int,
|
||||
rope_start: int,
|
||||
cat_token: bool,
|
||||
) -> dict:
|
||||
if backbone_name not in _BACKBONE_PRESETS:
|
||||
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
|
||||
cfg = dict(_BACKBONE_PRESETS[backbone_name])
|
||||
cfg.update(dict(
|
||||
layer_norm_eps=1e-6,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
# No mask_token in DA3 weights; omit param to avoid load warnings.
|
||||
use_mask_token=False,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
rope_freq=100.0,
|
||||
))
|
||||
return cfg
|
||||
|
||||
|
||||
class DepthAnything3Net(nn.Module):
|
||||
"""ComfyUI-side DepthAnything3 network.
|
||||
|
||||
Parameters mirror the variant YAML configs from the upstream repo and
|
||||
are auto-detected from the state dict by ``comfy/model_detection.py``.
|
||||
The kwargs ``device``, ``dtype`` and ``operations`` are injected by
|
||||
``BaseModel``.
|
||||
"""
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# --- Backbone ---
|
||||
backbone_name: str = "vitl",
|
||||
out_layers: Sequence[int] = (4, 11, 17, 23),
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
cat_token: bool = False,
|
||||
# --- Head ---
|
||||
head_type: str = "dpt", # "dpt" or "dualdpt"
|
||||
head_dim_in: int = 1024,
|
||||
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
|
||||
head_features: int = 256,
|
||||
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
head_use_sky_head: bool = True, # ignored by DualDPT
|
||||
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
|
||||
# --- Camera (multi-view) ---
|
||||
has_cam_enc: bool = False,
|
||||
has_cam_dec: bool = False,
|
||||
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
|
||||
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
|
||||
# ComfyUI plumbing
|
||||
device=None, dtype=None, operations=None,
|
||||
**_ignored,
|
||||
):
|
||||
super().__init__()
|
||||
head_cls = _HEAD_REGISTRY[head_type.lower()]
|
||||
self.head_type = head_type.lower()
|
||||
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
|
||||
self.has_conf = head_output_dim > 1
|
||||
self.out_layers = list(out_layers)
|
||||
|
||||
backbone_cfg = _build_backbone_config(
|
||||
backbone_name,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
)
|
||||
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
|
||||
|
||||
head_kwargs = dict(
|
||||
dim_in=head_dim_in,
|
||||
patch_size=self.PATCH_SIZE,
|
||||
output_dim=head_output_dim,
|
||||
features=head_features,
|
||||
out_channels=tuple(head_out_channels),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
if self.head_type == "dpt":
|
||||
head_kwargs.update(
|
||||
use_sky_head=head_use_sky_head,
|
||||
pos_embed=(False if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
else: # dualdpt
|
||||
head_kwargs.update(
|
||||
pos_embed=(True if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
self.head = head_cls(**head_kwargs)
|
||||
|
||||
# Built only if checkpoint has weights; cam_enc output dim == embed_dim.
|
||||
embed_dim = backbone_cfg["hidden_size"]
|
||||
if has_cam_enc:
|
||||
self.cam_enc = CameraEnc(
|
||||
dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
|
||||
num_heads=max(1, embed_dim // 64),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
else:
|
||||
self.cam_enc = None
|
||||
if has_cam_dec:
|
||||
default_dim = embed_dim * (2 if cat_token else 1)
|
||||
self.cam_dec = CameraDec(
|
||||
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
else:
|
||||
self.cam_dec = None
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
extrinsics: Optional[torch.Tensor] = None,
|
||||
intrinsics: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
use_ray_pose: bool = False,
|
||||
ref_view_strategy: str = "saddle_balanced",
|
||||
export_feat_layers: Optional[Sequence[int]] = None,
|
||||
**_unused,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Run depth (and optionally pose) prediction.
|
||||
|
||||
Args:
|
||||
image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or
|
||||
``(B, S, 3, H, W)`` for multi-view inputs. ``H`` and ``W``
|
||||
must be multiples of 14.
|
||||
extrinsics: optional ``(B, S, 4, 4)`` world-to-camera extrinsics.
|
||||
When provided together with ``intrinsics``, ``CameraEnc``
|
||||
converts them into per-view camera tokens that the backbone
|
||||
injects at block ``alt_start``.
|
||||
intrinsics: optional ``(B, S, 3, 3)`` pixel-space intrinsics.
|
||||
use_ray_pose: if True, predict pose from the auxiliary "ray" head
|
||||
(RANSAC over per-pixel rays). Only available on DualDPT
|
||||
variants. If False (default) and ``cam_dec`` is present,
|
||||
the final-layer cam token is decoded into pose instead.
|
||||
ref_view_strategy: reference-view selection strategy used when
|
||||
``S >= 3`` and no extrinsics are supplied. See
|
||||
:mod:`comfy.ldm.depth_anything_3.reference_view_selector`.
|
||||
export_feat_layers: optional list of backbone layer indices whose
|
||||
local features to also return as auxiliary outputs (used by
|
||||
downstream nested-architecture wrappers; empty by default).
|
||||
|
||||
Returns:
|
||||
Dict with a subset of:
|
||||
- ``depth`` ``(B*S, H, W)`` raw depth values.
|
||||
- ``depth_conf`` ``(B*S, H, W)`` confidence (DualDPT only).
|
||||
- ``sky`` ``(B*S, H, W)`` sky probability (DPT + sky head).
|
||||
- ``ray`` ``(B, S, h, w, 6)`` per-pixel cam ray (DualDPT,
|
||||
multi-view, ``use_ray_pose=True`` only).
|
||||
- ``ray_conf`` ``(B, S, h, w)`` ray confidence.
|
||||
- ``extrinsics`` ``(B, S, 4, 4)`` world-to-cam, when pose
|
||||
prediction is active.
|
||||
- ``intrinsics`` ``(B, S, 3, 3)`` pixel-space intrinsics.
|
||||
- ``aux_features`` list of ``(B, S, h_p, w_p, C)`` features
|
||||
when ``export_feat_layers`` is non-empty.
|
||||
"""
|
||||
if image.ndim == 4:
|
||||
image = image.unsqueeze(1) # (B, 1, 3, H, W)
|
||||
assert image.ndim == 5 and image.shape[2] == 3, \
|
||||
f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
|
||||
|
||||
B, S, _, H, W = image.shape
|
||||
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
|
||||
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
|
||||
|
||||
# Camera-token preparation (multi-view path).
|
||||
cam_token = None
|
||||
if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
|
||||
cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
|
||||
|
||||
# Toggle aux ray output on/off depending on what the caller asked for.
|
||||
if isinstance(self.head, DualDPT):
|
||||
self.head.enable_aux = bool(use_ray_pose)
|
||||
|
||||
feats, aux_feats = self.backbone.get_intermediate_layers_da3(
|
||||
image, self.out_layers, cam_token=cam_token,
|
||||
ref_view_strategy=ref_view_strategy,
|
||||
export_feat_layers=export_feat_layers,
|
||||
)
|
||||
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
|
||||
|
||||
# Pose prediction.
|
||||
out: Dict[str, torch.Tensor] = {}
|
||||
if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
|
||||
ray = head_out["ray"]
|
||||
ray_conf = head_out["ray_conf"]
|
||||
extr_c2w, focal, pp = get_extrinsic_from_camray(
|
||||
ray, ray_conf, ray.shape[-3], ray.shape[-2],
|
||||
)
|
||||
# Match the upstream output: w2c, drop the homogeneous row.
|
||||
extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
|
||||
# Build pixel-space intrinsics from the normalised focal/pp output.
|
||||
intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
|
||||
intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
|
||||
intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
|
||||
intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
|
||||
intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
|
||||
intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
|
||||
out["extrinsics"] = extr_w2c
|
||||
out["intrinsics"] = intr
|
||||
elif self.cam_dec is not None and S > 1:
|
||||
# Decode the cam-token of the final out_layer into a pose encoding.
|
||||
cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
|
||||
pose_enc = self.cam_dec(cam_feat)
|
||||
c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
|
||||
# Match the upstream output convention: w2c (world->camera), 3x4.
|
||||
c2w_4x4 = torch.cat([
|
||||
c2w_3x4,
|
||||
torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
|
||||
.view(1, 1, 1, 4).expand(B, S, 1, 4),
|
||||
], dim=-2)
|
||||
out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
|
||||
out["intrinsics"] = intr
|
||||
|
||||
# Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
|
||||
# per-image consumer keeps its (B*S, H, W) interface.
|
||||
for k, v in head_out.items():
|
||||
if k in ("ray", "ray_conf"):
|
||||
# Keep multi-view shape for downstream pose work.
|
||||
out[k] = v
|
||||
elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
|
||||
out[k] = v.reshape(B * S, *v.shape[2:])
|
||||
else:
|
||||
out[k] = v
|
||||
|
||||
if export_feat_layers:
|
||||
out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
|
||||
return out
|
||||
|
||||
def _reshape_aux_features(self, aux_feats, H: int, W: int):
|
||||
"""Reshape ``(B, S, N, C)`` aux features into ``(B, S, h_p, w_p, C)``."""
|
||||
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
|
||||
out = []
|
||||
for f in aux_feats:
|
||||
B, S, N, C = f.shape
|
||||
assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
|
||||
out.append(f.reshape(B, S, ph, pw, C))
|
||||
return out
|
||||
184
comfy/ldm/depth_anything_3/preprocess.py
Normal file
184
comfy/ldm/depth_anything_3/preprocess.py
Normal file
@ -0,0 +1,184 @@
|
||||
# Input/output preprocessing helpers for Depth Anything 3.
|
||||
#
|
||||
# Ported from:
|
||||
# src/depth_anything_3/utils/io/input_processor.py (image normalisation)
|
||||
# src/depth_anything_3/utils/alignment.py (sky-aware depth clip)
|
||||
# src/depth_anything_3/model/da3.py::_process_mono_sky_estimation
|
||||
#
|
||||
# Resize: ``comfy.utils.common_upscale`` with ``upscale_method="lanczos"``.
|
||||
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale); a sweep
|
||||
# across {bilinear, bicubic, area, lanczos, bislerp} on a 768->504 test image
|
||||
# showed lanczos has the lowest max-abs-diff vs the upstream cv2 output
|
||||
# (~0.13 vs 0.21-0.71 for the others), so we use it in both directions for
|
||||
# simplicity. This keeps the path stateless, on-device, and free of any
|
||||
# OpenCV dependency.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
# ImageNet normalization constants used during DA3 training.
|
||||
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
|
||||
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
|
||||
down = (x // patch) * patch
|
||||
up = down + patch
|
||||
return up if abs(up - x) <= abs(x - down) else down
|
||||
|
||||
|
||||
def compute_target_size(orig_h: int, orig_w: int, process_res: int,
|
||||
method: str = "upper_bound_resize") -> Tuple[int, int]:
|
||||
"""Compute (target_h, target_w) for a single image.
|
||||
|
||||
Methods:
|
||||
- "upper_bound_resize": scale longest side to ``process_res``, then
|
||||
round each dim to nearest multiple of 14 (default upstream method).
|
||||
- "lower_bound_resize": scale shortest side to ``process_res``, then
|
||||
round.
|
||||
"""
|
||||
if method == "upper_bound_resize":
|
||||
longest = max(orig_h, orig_w)
|
||||
scale = process_res / float(longest)
|
||||
elif method == "lower_bound_resize":
|
||||
shortest = min(orig_h, orig_w)
|
||||
scale = process_res / float(shortest)
|
||||
else:
|
||||
raise ValueError(f"Unsupported process_res_method: {method}")
|
||||
|
||||
new_w = max(1, _round_to_patch(int(round(orig_w * scale))))
|
||||
new_h = max(1, _round_to_patch(int(round(orig_h * scale))))
|
||||
return new_h, new_w
|
||||
|
||||
|
||||
def preprocess_image(
|
||||
image: torch.Tensor,
|
||||
process_res: int = 504,
|
||||
method: str = "upper_bound_resize",
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a ComfyUI ``IMAGE`` batch for DA3.
|
||||
|
||||
Args:
|
||||
image: ``(B, H, W, 3)`` float in [0, 1] (ComfyUI ``IMAGE`` convention).
|
||||
process_res: target resolution (longest or shortest side, depending
|
||||
on ``method``).
|
||||
method: resize strategy.
|
||||
|
||||
Returns:
|
||||
``(B, 3, H', W')`` tensor with H' and W' multiples of 14, normalised
|
||||
with ImageNet statistics. The tensor lives on the same device as
|
||||
``image``.
|
||||
"""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
B, H, W, _ = image.shape
|
||||
target_h, target_w = compute_target_size(H, W, process_res, method)
|
||||
|
||||
# (B, H, W, 3) -> (B, 3, H, W)
|
||||
x = image.movedim(-1, 1).contiguous()
|
||||
if (target_h, target_w) != (H, W):
|
||||
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale).
|
||||
# Lanczos in ``common_upscale`` is anti-aliased and produces the
|
||||
# closest pixel-wise match in a sweep across {bilinear, bicubic,
|
||||
# area, lanczos, bislerp}. Used in both directions for simplicity.
|
||||
x = comfy.utils.common_upscale(
|
||||
x.float(), target_w, target_h, "lanczos", "disabled",
|
||||
)
|
||||
x = x.clamp(0.0, 1.0)
|
||||
|
||||
mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
x = (x - mean) / std
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Output post-processing (sky-aware clipping for Mono/Metric variants)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
|
||||
"""Boolean mask: True for non-sky pixels (sky probability < threshold)."""
|
||||
return sky_prediction < threshold
|
||||
|
||||
|
||||
def apply_sky_aware_clip(
|
||||
depth: torch.Tensor,
|
||||
sky: torch.Tensor,
|
||||
threshold: float = 0.3,
|
||||
quantile: float = 0.99,
|
||||
) -> torch.Tensor:
|
||||
"""Replicates ``_process_mono_sky_estimation`` from upstream.
|
||||
|
||||
Clips sky regions to the 99th percentile of non-sky depth. Returns a new
|
||||
depth tensor; ``depth`` is not modified in place.
|
||||
"""
|
||||
non_sky = compute_non_sky_mask(sky, threshold=threshold)
|
||||
if non_sky.sum() <= 10 or (~non_sky).sum() <= 10:
|
||||
return depth.clone()
|
||||
|
||||
non_sky_depth = depth[non_sky]
|
||||
if non_sky_depth.numel() > 100_000:
|
||||
idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device)
|
||||
sampled = non_sky_depth[idx]
|
||||
else:
|
||||
sampled = non_sky_depth
|
||||
|
||||
max_depth = torch.quantile(sampled, quantile)
|
||||
out = depth.clone()
|
||||
out[~non_sky] = max_depth
|
||||
return out
|
||||
|
||||
|
||||
def normalize_depth_v2_style(
|
||||
depth: torch.Tensor,
|
||||
sky: torch.Tensor | None = None,
|
||||
low_quantile: float = 0.01,
|
||||
high_quantile: float = 0.99,
|
||||
) -> torch.Tensor:
|
||||
"""V2-style normalization for ControlNet workflows.
|
||||
|
||||
Computes percentile bounds over non-sky pixels (when available),
|
||||
then maps depth into [0, 1] with near = white (1.0).
|
||||
"""
|
||||
if sky is not None:
|
||||
mask = compute_non_sky_mask(sky)
|
||||
if mask.any():
|
||||
valid = depth[mask]
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
|
||||
if valid.numel() > 100_000:
|
||||
idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device)
|
||||
sample = valid[idx]
|
||||
else:
|
||||
sample = valid
|
||||
|
||||
lo = torch.quantile(sample, low_quantile)
|
||||
hi = torch.quantile(sample, high_quantile)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
# ControlNet convention: nearer pixels are brighter (1.0).
|
||||
norm = 1.0 - norm
|
||||
if sky is not None:
|
||||
# Sky pixels become black (far / unknown).
|
||||
sky_mask = ~compute_non_sky_mask(sky)
|
||||
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
|
||||
return norm
|
||||
|
||||
|
||||
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
|
||||
"""Simple per-frame min/max normalization with near=1.0 convention."""
|
||||
lo = depth.amin(dim=(-2, -1), keepdim=True)
|
||||
hi = depth.amax(dim=(-2, -1), keepdim=True)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
318
comfy/ldm/depth_anything_3/ray_pose.py
Normal file
318
comfy/ldm/depth_anything_3/ray_pose.py
Normal file
@ -0,0 +1,318 @@
|
||||
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3.
|
||||
|
||||
Converts the auxiliary "ray" output of :class:`DualDPT` (per-pixel camera
|
||||
ray vectors, predicted on the per-view local feature map) into per-view
|
||||
extrinsics + intrinsics. Implementation is a 1:1 port of
|
||||
``depth_anything_3.utils.ray_utils`` upstream, using a weighted-RANSAC
|
||||
homography fit followed by a QL decomposition.
|
||||
|
||||
No learned parameters; pure tensor math. Output:
|
||||
|
||||
* ``R`` -- ``(B, S, 3, 3)`` rotation matrix
|
||||
* ``T`` -- ``(B, S, 3)`` camera-space translation
|
||||
* ``focal_lengths`` -- ``(B, S, 2)`` in normalised image space (image=2x2)
|
||||
* ``principal_points`` -- ``(B, S, 2)`` ditto
|
||||
|
||||
:func:`get_extrinsic_from_camray` wraps these into a 4x4 extrinsic matrix
|
||||
that the public node converts back into pixel-space intrinsics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops.
|
||||
|
||||
|
||||
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Decompose ``A = Q @ L`` with ``Q`` orthogonal and ``L`` lower-triangular.
|
||||
|
||||
Implemented in terms of QR by reversing the columns/rows; the standard
|
||||
trick from the upstream reference. Inputs ``A`` are ``(3, 3)``.
|
||||
"""
|
||||
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]],
|
||||
device=A.device, dtype=A.dtype)
|
||||
A_tilde = A @ P
|
||||
# CUDA QR is not implemented for fp16/bf16; upcast just for this call.
|
||||
Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
|
||||
Q_tilde = Q_tilde.to(A.dtype)
|
||||
R_tilde = R_tilde.to(A.dtype)
|
||||
Q = Q_tilde @ P
|
||||
L = P @ R_tilde @ P
|
||||
d = torch.diag(L)
|
||||
sign = torch.sign(d)
|
||||
Q = Q * sign[None, :] # scale columns of Q
|
||||
L = L * sign[:, None] # scale rows of L
|
||||
return Q, L
|
||||
|
||||
|
||||
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
|
||||
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Weighted-LSQ + RANSAC homography (batched)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq(
|
||||
src_pts: torch.Tensor,
|
||||
dst_pts: torch.Tensor,
|
||||
confident_weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Solve a single ``H`` with weighted least-squares (DLT)."""
|
||||
N = src_pts.shape[0]
|
||||
if N < 4:
|
||||
raise ValueError("At least 4 points are required to compute a homography.")
|
||||
w = confident_weight.sqrt().unsqueeze(1) # (N, 1)
|
||||
x = src_pts[:, 0:1]
|
||||
y = src_pts[:, 1:2]
|
||||
u = dst_pts[:, 0:1]
|
||||
v = dst_pts[:, 1:2]
|
||||
zeros = torch.zeros_like(x)
|
||||
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
|
||||
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
|
||||
A = torch.cat([A1, A2], dim=0) # (2N, 9)
|
||||
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
||||
_, _, Vh = torch.linalg.svd(A.float())
|
||||
Vh = Vh.to(A.dtype)
|
||||
H = Vh[-1].reshape(3, 3)
|
||||
return H / H[-1, -1]
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq_batched(
|
||||
src_pts_batch: torch.Tensor,
|
||||
dst_pts_batch: torch.Tensor,
|
||||
confident_weight_batch: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Batched DLT solver. Inputs ``(B, K, 2)`` / ``(B, K)``; output ``(B, 3, 3)``."""
|
||||
B, K, _ = src_pts_batch.shape
|
||||
w = confident_weight_batch.sqrt().unsqueeze(2)
|
||||
x = src_pts_batch[:, :, 0:1]
|
||||
y = src_pts_batch[:, :, 1:2]
|
||||
u = dst_pts_batch[:, :, 0:1]
|
||||
v = dst_pts_batch[:, :, 1:2]
|
||||
zeros = torch.zeros_like(x)
|
||||
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
|
||||
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
|
||||
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
|
||||
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
||||
_, _, Vh = torch.linalg.svd(A.float())
|
||||
Vh = Vh.to(A.dtype)
|
||||
H = Vh[:, -1].reshape(B, 3, 3)
|
||||
return H / H[:, 2:3, 2:3]
|
||||
|
||||
|
||||
def _ransac_find_homography_weighted_batched(
|
||||
src_pts: torch.Tensor, # (B, N, 2)
|
||||
dst_pts: torch.Tensor, # (B, N, 2)
|
||||
confident_weight: torch.Tensor, # (B, N)
|
||||
n_sample: int,
|
||||
n_iter: int = 100,
|
||||
reproj_threshold: float = 3.0,
|
||||
num_sample_for_ransac: int = 8,
|
||||
max_inlier_num: int = 10000,
|
||||
rand_sample_iters_idx: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Batched weighted-RANSAC homography estimator.
|
||||
|
||||
Returns ``(B, 3, 3)`` homography matrices.
|
||||
"""
|
||||
B, N, _ = src_pts.shape
|
||||
assert N >= 4
|
||||
device = src_pts.device
|
||||
|
||||
sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
|
||||
candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample)
|
||||
|
||||
if rand_sample_iters_idx is None:
|
||||
rand_sample_iters_idx = torch.stack(
|
||||
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
|
||||
for _ in range(n_iter)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k)
|
||||
b_idx = (
|
||||
torch.arange(B, device=device)
|
||||
.view(B, 1, 1)
|
||||
.expand(B, n_iter, num_sample_for_ransac)
|
||||
)
|
||||
src_b = src_pts[b_idx, rand_idx]
|
||||
dst_b = dst_pts[b_idx, rand_idx]
|
||||
w_b = confident_weight[b_idx, rand_idx]
|
||||
|
||||
cB, cN = src_b.shape[:2]
|
||||
H_batch = _find_homography_weighted_lsq_batched(
|
||||
src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
|
||||
).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3)
|
||||
|
||||
src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
|
||||
proj = torch.bmm(
|
||||
src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
|
||||
H_batch.reshape(-1, 3, 3).transpose(1, 2),
|
||||
) # (B*n_iter, N, 3)
|
||||
proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
|
||||
err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N)
|
||||
inlier_mask = err < reproj_threshold
|
||||
score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
|
||||
best_idx = torch.argmax(score, dim=1)
|
||||
best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]
|
||||
|
||||
# Refit with the inlier set (per-batch, since the inlier counts vary).
|
||||
H_inlier_list = []
|
||||
for b in range(B):
|
||||
mask = best_inlier_mask[b]
|
||||
in_src = src_pts[b][mask]
|
||||
in_dst = dst_pts[b][mask]
|
||||
in_w = confident_weight[b][mask]
|
||||
if in_src.shape[0] < 4:
|
||||
# Fall back to identity when RANSAC fails to find enough inliers.
|
||||
H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
|
||||
continue
|
||||
sorted_w = torch.argsort(in_w, descending=True)
|
||||
if len(sorted_w) > max_inlier_num:
|
||||
keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
|
||||
sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
|
||||
H_inlier_list.append(
|
||||
_find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
|
||||
)
|
||||
return torch.stack(H_inlier_list, dim=0)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Camera-ray utilities
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _unproject_identity(num_y: int, num_x: int, B: int, S: int,
|
||||
device, dtype) -> torch.Tensor:
|
||||
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane.
|
||||
|
||||
Replicates ``unproject_depth(..., ixt_normalized=True)`` upstream: pixel
|
||||
coords ``(x, y)`` in ``[dx, 2-dx] x [dy, 2-dy]`` get mapped to
|
||||
camera-space rays ``(x-1, y-1, 1)`` via the identity intrinsic
|
||||
``[[1,0,1],[0,1,1],[0,0,1]]``. Returns ``(B, S, num_y, num_x, 3)``.
|
||||
"""
|
||||
dx = 1.0 / num_x
|
||||
dy = 1.0 / num_y
|
||||
# Centered camera-space coords directly (skip the K^-1 step since it's
|
||||
# just a translation by -1 on x and y when K is identity-with-center=1).
|
||||
y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype)
|
||||
x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype)
|
||||
yy, xx = torch.meshgrid(y, x, indexing="ij")
|
||||
grid = torch.stack((xx, yy), dim=-1) # (h, w, 2)
|
||||
grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2)
|
||||
return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
|
||||
|
||||
|
||||
def _camray_to_caminfo(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
confidence: Optional[torch.Tensor] = None, # (B, S, h, w)
|
||||
reproj_threshold: float = 0.2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Convert per-pixel camera rays to per-view (R, T, focal, principal)."""
|
||||
if confidence is None:
|
||||
confidence = torch.ones_like(camray[..., 0])
|
||||
B, S, h, w, _ = camray.shape
|
||||
device = camray.device
|
||||
dtype = camray.dtype
|
||||
|
||||
rays_target = camray[..., :3] # (B, S, h, w, 3)
|
||||
rays_origin = _unproject_identity(h, w, B, S, device, dtype)
|
||||
|
||||
# Flatten (B*S, h*w, *) for the RANSAC routine.
|
||||
rays_target = rays_target.flatten(0, 1).flatten(1, 2)
|
||||
rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
|
||||
weights = confidence.flatten(0, 1).flatten(1, 2).clone()
|
||||
|
||||
# Project to 2D in homogeneous form (the upstream calls this "perspective division").
|
||||
z_thresh = 1e-4
|
||||
mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
|
||||
weights = torch.where(mask, weights, torch.zeros_like(weights))
|
||||
src = rays_origin.clone()
|
||||
dst = rays_target.clone()
|
||||
src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
|
||||
src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
|
||||
dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
|
||||
dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
|
||||
src = src[..., :2]
|
||||
dst = dst[..., :2]
|
||||
|
||||
N = src.shape[1]
|
||||
n_iter = 100
|
||||
sample_ratio = 0.3
|
||||
num_sample_for_ransac = 8
|
||||
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
|
||||
rand_idx = torch.stack(
|
||||
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Chunk along the view axis to keep peak memory predictable.
|
||||
chunk = 2
|
||||
A_list = []
|
||||
for i in range(0, src.shape[0], chunk):
|
||||
A = _ransac_find_homography_weighted_batched(
|
||||
src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
|
||||
n_sample=n_sample, n_iter=n_iter,
|
||||
num_sample_for_ransac=num_sample_for_ransac,
|
||||
reproj_threshold=reproj_threshold,
|
||||
rand_sample_iters_idx=rand_idx,
|
||||
max_inlier_num=8000,
|
||||
)
|
||||
# Flip sign on dets that come out < 0 (so that the QL produces a
|
||||
# right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
|
||||
# do the comparison in fp32.
|
||||
flip = torch.linalg.det(A.float()) < 0
|
||||
A = torch.where(flip[:, None, None], -A, A)
|
||||
A_list.append(A)
|
||||
A = torch.cat(A_list, dim=0) # (B*S, 3, 3)
|
||||
|
||||
R_list, f_list, pp_list = [], [], []
|
||||
for i in range(A.shape[0]):
|
||||
R, L = _ql_decomposition(A[i])
|
||||
L = L / L[2][2]
|
||||
f_list.append(torch.stack((L[0][0], L[1][1])))
|
||||
pp_list.append(torch.stack((L[2][0], L[2][1])))
|
||||
R_list.append(R)
|
||||
R = torch.stack(R_list).reshape(B, S, 3, 3)
|
||||
focal = torch.stack(f_list).reshape(B, S, 2)
|
||||
pp = torch.stack(pp_list).reshape(B, S, 2)
|
||||
|
||||
# Translation: confidence-weighted average of camray direction(s).
|
||||
cf = confidence.flatten(0, 1).flatten(1, 2)
|
||||
T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
|
||||
T = T / cf.sum(dim=-1, keepdim=True)
|
||||
T = T.reshape(B, S, 3)
|
||||
|
||||
# Match upstream output convention: focal -> 1/focal, pp + 1.
|
||||
return R, T, 1.0 / focal, pp + 1.0
|
||||
|
||||
|
||||
def get_extrinsic_from_camray(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
|
||||
patch_size_y: int,
|
||||
patch_size_x: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output.
|
||||
|
||||
Returns:
|
||||
* extrinsic ``(B, S, 4, 4)`` camera-to-world (the inverse is
|
||||
what gets stored in ``output.extrinsics``
|
||||
by the caller).
|
||||
* focals ``(B, S, 2)`` in normalised image space.
|
||||
* pp ``(B, S, 2)`` in normalised image space.
|
||||
"""
|
||||
if conf.ndim == 5 and conf.shape[-1] == 1:
|
||||
conf = conf.squeeze(-1)
|
||||
R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
|
||||
extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4)
|
||||
homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
|
||||
homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4)
|
||||
extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4)
|
||||
return extr, focal, pp
|
||||
116
comfy/ldm/depth_anything_3/reference_view_selector.py
Normal file
116
comfy/ldm/depth_anything_3/reference_view_selector.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""Reference-view selection for the multi-view path of Depth Anything 3.
|
||||
|
||||
Pure tensor math, no learned parameters. Exposed as three free functions:
|
||||
|
||||
* :func:`select_reference_view` -- pick a reference view per batch.
|
||||
* :func:`reorder_by_reference` -- move the reference view to position 0.
|
||||
* :func:`restore_original_order` -- inverse of :func:`reorder_by_reference`.
|
||||
|
||||
Mirrors ``depth_anything_3.model.reference_view_selector`` upstream.
|
||||
The default strategy (``"saddle_balanced"``) selects the view whose CLS
|
||||
token features are closest to the median across multiple metrics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
|
||||
|
||||
|
||||
# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
|
||||
# Reference selection only runs when there are at least this many views.
|
||||
THRESH_FOR_REF_SELECTION: int = 3
|
||||
|
||||
|
||||
def select_reference_view(
|
||||
x: torch.Tensor,
|
||||
strategy: RefViewStrategy = "saddle_balanced",
|
||||
) -> torch.Tensor:
|
||||
"""Pick a reference view index per batch element.
|
||||
|
||||
Args:
|
||||
x: ``(B, S, N, C)`` token tensor. Index 0 along ``N`` is the
|
||||
cls/cam token used by the feature-based strategies.
|
||||
strategy: One of ``"first" | "middle" | "saddle_balanced" |
|
||||
"saddle_sim_range"``.
|
||||
|
||||
Returns:
|
||||
``(B,)`` long tensor with the chosen reference view index for
|
||||
each batch element.
|
||||
"""
|
||||
B, S, _, _ = x.shape
|
||||
if S <= 1:
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
if strategy == "first":
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
if strategy == "middle":
|
||||
return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
|
||||
|
||||
# Feature-based strategies: normalised cls/cam token per view.
|
||||
img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C)
|
||||
|
||||
if strategy == "saddle_balanced":
|
||||
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S)
|
||||
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
|
||||
sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S)
|
||||
feat_norm = x[:, :, 0].norm(dim=-1) # (B,S)
|
||||
feat_var = img_class_feat.var(dim=-1) # (B,S)
|
||||
|
||||
def _normalize(metric):
|
||||
mn = metric.min(dim=1, keepdim=True).values
|
||||
mx = metric.max(dim=1, keepdim=True).values
|
||||
return (metric - mn) / (mx - mn + 1e-8)
|
||||
|
||||
sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var)
|
||||
balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs()
|
||||
return balance.argmin(dim=1)
|
||||
|
||||
if strategy == "saddle_sim_range":
|
||||
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))
|
||||
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
|
||||
sim_max = sim_no_diag.max(dim=-1).values
|
||||
sim_min = sim_no_diag.min(dim=-1).values
|
||||
return (sim_max - sim_min).argmax(dim=1)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown reference view selection strategy: {strategy!r}. "
|
||||
f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
|
||||
)
|
||||
|
||||
|
||||
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Reorder ``x`` so the reference view is at position 0 in axis ``S``."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
reorder = torch.where(
|
||||
(positions > 0) & (positions <= b_idx_exp),
|
||||
positions - 1,
|
||||
positions,
|
||||
)
|
||||
reorder[:, 0] = b_idx
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, reorder]
|
||||
|
||||
|
||||
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of :func:`reorder_by_reference`."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
restore = torch.where(target_positions < b_idx_exp,
|
||||
target_positions + 1,
|
||||
target_positions)
|
||||
restore = torch.scatter(
|
||||
restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp),
|
||||
)
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, restore]
|
||||
183
comfy/ldm/depth_anything_3/transform.py
Normal file
183
comfy/ldm/depth_anything_3/transform.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""Geometry / camera transform helpers for Depth Anything 3.
|
||||
|
||||
Pure tensor math, no learned parameters. Mirrors the upstream upstream
|
||||
``depth_anything_3.model.utils.transform`` and the parts of
|
||||
``depth_anything_3.utils.geometry`` used at inference time on the
|
||||
multi-view + camera path. Kept self-contained so the DA3 module is fully
|
||||
ported and does not depend on the upstream repo at runtime.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Affine 4x4 helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
|
||||
"""Promote ``(...,3,4)`` extrinsics to ``(...,4,4)`` homogeneous form.
|
||||
|
||||
A no-op when the input is already ``(...,4,4)``.
|
||||
"""
|
||||
if ext.shape[-2:] == (4, 4):
|
||||
return ext
|
||||
if ext.shape[-2:] == (3, 4):
|
||||
ones = torch.zeros_like(ext[..., :1, :4])
|
||||
ones[..., 0, 3] = 1.0
|
||||
return torch.cat([ext, ones], dim=-2)
|
||||
raise ValueError(f"Invalid affine shape: {ext.shape}")
|
||||
|
||||
|
||||
def affine_inverse(A: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of an affine matrix ``[R|T; 0 0 0 1]``."""
|
||||
R = A[..., :3, :3]
|
||||
T = A[..., :3, 3:]
|
||||
P = A[..., 3:, :]
|
||||
return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Quaternion <-> rotation matrix (xyzw / scalar-last)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
"""``sqrt(max(0, x))`` with a zero subgradient where ``x == 0``."""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
else:
|
||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||
return ret
|
||||
|
||||
|
||||
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Force the real part of a unit quaternion (xyzw) to be non-negative."""
|
||||
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
||||
|
||||
|
||||
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert quaternions (xyzw) to ``(...,3,3)`` rotation matrices."""
|
||||
i, j, k, r = torch.unbind(quaternions, -1)
|
||||
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
||||
o = torch.stack(
|
||||
(
|
||||
1 - two_s * (j * j + k * k),
|
||||
two_s * (i * j - k * r),
|
||||
two_s * (i * k + j * r),
|
||||
two_s * (i * j + k * r),
|
||||
1 - two_s * (i * i + k * k),
|
||||
two_s * (j * k - i * r),
|
||||
two_s * (i * k - j * r),
|
||||
two_s * (j * k + i * r),
|
||||
1 - two_s * (i * i + j * j),
|
||||
),
|
||||
-1,
|
||||
)
|
||||
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
||||
|
||||
|
||||
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert ``(...,3,3)`` rotation matrices to quaternions (xyzw)."""
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||||
|
||||
batch_dim = matrix.shape[:-2]
|
||||
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
||||
matrix.reshape(batch_dim + (9,)), dim=-1
|
||||
)
|
||||
|
||||
q_abs = _sqrt_positive_part(
|
||||
torch.stack(
|
||||
[
|
||||
1.0 + m00 + m11 + m22,
|
||||
1.0 + m00 - m11 - m22,
|
||||
1.0 - m00 + m11 - m22,
|
||||
1.0 - m00 - m11 + m22,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
|
||||
quat_by_rijk = torch.stack(
|
||||
[
|
||||
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
||||
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
||||
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
||||
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
||||
],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
||||
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
||||
|
||||
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
|
||||
batch_dim + (4,)
|
||||
)
|
||||
# Reorder rijk -> xyzw (i.e. ijkr).
|
||||
out = out[..., [1, 2, 3, 0]]
|
||||
return standardize_quaternion(out)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Pose-encoding <-> extrinsics + intrinsics
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extri_intri_to_pose_encoding(
|
||||
extrinsics: torch.Tensor,
|
||||
intrinsics: torch.Tensor,
|
||||
image_size_hw: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""Pack ``(extr, intr, image_size)`` into the 9-D pose-encoding vector.
|
||||
|
||||
``extrinsics`` are camera-to-world (c2w) ``(B,S,4,4)`` matrices,
|
||||
``intrinsics`` are pixel-space ``(B,S,3,3)`` matrices, ``image_size_hw``
|
||||
is a ``(H, W)`` pair. The encoding is ``[T(3), quat_xyzw(4), fov_h, fov_w]``.
|
||||
"""
|
||||
R = extrinsics[..., :3, :3]
|
||||
T = extrinsics[..., :3, 3]
|
||||
quat = mat_to_quat(R)
|
||||
H, W = image_size_hw
|
||||
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
|
||||
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
|
||||
return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
|
||||
|
||||
|
||||
def pose_encoding_to_extri_intri(
|
||||
pose_encoding: torch.Tensor,
|
||||
image_size_hw: Tuple[int, int],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Inverse of :func:`extri_intri_to_pose_encoding`.
|
||||
|
||||
Returns a ``(B,S,3,4)`` c2w extrinsic matrix and a ``(B,S,3,3)``
|
||||
pixel-space intrinsic matrix.
|
||||
"""
|
||||
T = pose_encoding[..., :3]
|
||||
quat = pose_encoding[..., 3:7]
|
||||
fov_h = pose_encoding[..., 7]
|
||||
fov_w = pose_encoding[..., 8]
|
||||
# Normalize to unit quaternion. CameraDec outputs raw values; a near-zero
|
||||
# quaternion causes two_s = 2/norm² → inf in quat_to_mat → NaN extrinsics.
|
||||
quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-6)
|
||||
R = quat_to_mat(quat)
|
||||
extrinsics = torch.cat([R, T[..., None]], dim=-1)
|
||||
H, W = image_size_hw
|
||||
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
|
||||
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
|
||||
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3),
|
||||
device=pose_encoding.device, dtype=pose_encoding.dtype)
|
||||
intrinsics[..., 0, 0] = fx
|
||||
intrinsics[..., 1, 1] = fy
|
||||
intrinsics[..., 0, 2] = W / 2
|
||||
intrinsics[..., 1, 2] = H / 2
|
||||
intrinsics[..., 2, 2] = 1.0
|
||||
return extrinsics, intrinsics
|
||||
@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if output_size is None:
|
||||
output_size = hidden_size
|
||||
@ -221,10 +221,9 @@ class TimestepEmbedder(nn.Module):
|
||||
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
@ -1,239 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.hidream.model import FeedForwardSwiGLU
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
|
||||
from .modules import (
|
||||
FinalLayer,
|
||||
PatchTokenEmbedder,
|
||||
PiTBlock,
|
||||
PixelTokenEmbedder,
|
||||
apply_adaln_,
|
||||
precompute_freqs_cis_2d,
|
||||
)
|
||||
|
||||
|
||||
class MMDiTJointAttention(nn.Module):
|
||||
"""Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
|
||||
|
||||
RoPE is applied to each stream before concatenation so each stream uses its own
|
||||
2D/1D positional encoding. Concat order is [text, image] (text first).
|
||||
"""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
|
||||
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
|
||||
B, Nx, _ = x.shape
|
||||
_, Ny, _ = y.shape
|
||||
H = self.num_heads
|
||||
D = self.head_dim
|
||||
|
||||
qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
qx, kx, vx = qkv_x.unbind(0)
|
||||
qx = self.q_norm_x(qx)
|
||||
kx = self.k_norm_x(kx)
|
||||
|
||||
qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
qy, ky, vy = qkv_y.unbind(0)
|
||||
qy = self.q_norm_y(qy)
|
||||
ky = self.k_norm_y(ky)
|
||||
|
||||
qx, kx = apply_rope(qx, kx, pos_img[None, None])
|
||||
if pos_txt is not None:
|
||||
qy, ky = apply_rope(qy, ky, pos_txt[None, None])
|
||||
|
||||
q_joint = torch.cat([qy, qx], dim=2)
|
||||
k_joint = torch.cat([ky, kx], dim=2)
|
||||
v_joint = torch.cat([vy, vx], dim=2)
|
||||
|
||||
out_joint = optimized_attention(
|
||||
q_joint, k_joint, v_joint, H,
|
||||
mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
|
||||
out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
|
||||
|
||||
return self.proj_x(out_x), self.proj_y(out_y)
|
||||
|
||||
|
||||
class MMDiTBlockT2I(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations)
|
||||
self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
|
||||
self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
|
||||
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
|
||||
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
|
||||
|
||||
x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
|
||||
y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
|
||||
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, gate_msa_x, attn_x)
|
||||
y = torch.addcmul(y, gate_msa_y, attn_y)
|
||||
|
||||
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
|
||||
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
|
||||
return x, y
|
||||
|
||||
|
||||
class PixDiT_T2I(nn.Module):
|
||||
"""PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
|
||||
(also runs at 512px when fed the appropriate latent size and flow_shift).
|
||||
|
||||
Forward:
|
||||
x: [B, 3, H, W] pixel-space input (no VAE)
|
||||
timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
|
||||
context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
|
||||
Returns flow-matching velocity [B, 3, H, W].
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
num_groups=24,
|
||||
hidden_size=1536,
|
||||
pixel_hidden_size=16,
|
||||
pixel_attn_hidden_size=1152,
|
||||
pixel_num_groups=16,
|
||||
patch_depth=14,
|
||||
pixel_depth=2,
|
||||
patch_size=16,
|
||||
txt_embed_dim=2304,
|
||||
txt_max_length=300,
|
||||
use_text_rope=True,
|
||||
text_rope_theta=10000.0,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
pixel_mlp_chunks=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.patch_depth = patch_depth
|
||||
self.pixel_depth = pixel_depth
|
||||
self.patch_size = patch_size
|
||||
self.pixel_hidden_size = pixel_hidden_size
|
||||
self.pixel_attn_hidden_size = pixel_attn_hidden_size
|
||||
self.pixel_num_groups = pixel_num_groups
|
||||
self.txt_embed_dim = txt_embed_dim
|
||||
self.txt_max_length = txt_max_length
|
||||
self.use_text_rope = use_text_rope
|
||||
self.text_rope_theta = text_rope_theta
|
||||
|
||||
self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
|
||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
|
||||
self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
|
||||
self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))
|
||||
|
||||
self.patch_blocks = nn.ModuleList([
|
||||
MMDiTBlockT2I(self.hidden_size, self.num_groups,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(self.patch_depth)
|
||||
])
|
||||
self.pixel_blocks = nn.ModuleList([
|
||||
PiTBlock(
|
||||
self.pixel_hidden_size,
|
||||
self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
num_heads=self.num_groups,
|
||||
attn_hidden_size=self.pixel_attn_hidden_size,
|
||||
attn_num_heads=self.pixel_num_groups,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
mlp_chunks=pixel_mlp_chunks,
|
||||
)
|
||||
for _ in range(self.pixel_depth)
|
||||
])
|
||||
|
||||
self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)
|
||||
|
||||
def _fetch_text_pos(self, length, device, dtype):
|
||||
return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)
|
||||
|
||||
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _pre_patch_block(self, s, i, **kwargs):
|
||||
"""Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
|
||||
return s
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
H_orig, W_orig = x.shape[2], x.shape[3]
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
B, _, H, W = x.shape
|
||||
Hs = H // self.patch_size
|
||||
Ws = W // self.patch_size
|
||||
L = Hs * Ws
|
||||
|
||||
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
|
||||
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
|
||||
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
|
||||
|
||||
if context is None or context.dim() != 3:
|
||||
raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
|
||||
Ltxt = min(context.shape[1], self.txt_max_length)
|
||||
y = context[:, :Ltxt, :]
|
||||
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
|
||||
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
|
||||
|
||||
condition = F.silu(t_emb)
|
||||
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
|
||||
|
||||
s = self.s_embedder(x_patches)
|
||||
for i, blk in enumerate(self.patch_blocks):
|
||||
s = self._pre_patch_block(s, i, **kwargs)
|
||||
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
|
||||
s = F.silu(t_emb + s)
|
||||
|
||||
s_cond = s.view(B * L, self.hidden_size)
|
||||
x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
|
||||
for blk in self.pixel_blocks:
|
||||
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
|
||||
|
||||
x_pixels = self.final_layer(x_pixels)
|
||||
C_out = self.out_channels
|
||||
P2 = self.patch_size * self.patch_size
|
||||
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
|
||||
out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return out[:, :, :H_orig, :W_orig]
|
||||
@ -1,187 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def apply_adaln_(x, shift, scale):
|
||||
return x.addcmul_(x, scale).add_(shift)
|
||||
|
||||
|
||||
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
|
||||
ref_grid_h=None, ref_grid_w=None,
|
||||
scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
|
||||
device=None, dtype=torch.float32, **kwargs):
|
||||
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
|
||||
|
||||
rope_options:
|
||||
scale_x / scale_y multiply the position range (RoPE extrapolation).
|
||||
shift_x / shift_y offset the position origin (tiled / regional inference).
|
||||
With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling
|
||||
(rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).
|
||||
Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
|
||||
Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
|
||||
"""
|
||||
dim_axis = dim // 2
|
||||
if ref_grid_h is not None and dim_axis > 2:
|
||||
h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2))
|
||||
w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2))
|
||||
else:
|
||||
h_ntk = w_ntk = 1.0
|
||||
|
||||
x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
|
||||
y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
|
||||
y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
|
||||
x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
|
||||
y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
|
||||
out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
|
||||
return out.to(dtype=dtype)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
|
||||
"""Standard 2D sin/cos absolute positional embedding (ViT-style).
|
||||
|
||||
first half encodes W-coordinates, second half H.
|
||||
"""
|
||||
assert embed_dim % 4 == 0
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=device)
|
||||
grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device)
|
||||
return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)
|
||||
|
||||
|
||||
class RotaryAttention(nn.Module):
|
||||
"""Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, pos, mask=None, transformer_options={}):
|
||||
B, N, C = x.shape
|
||||
H = self.num_heads
|
||||
D = self.head_dim
|
||||
qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
|
||||
x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.norm(x))
|
||||
|
||||
|
||||
class PatchTokenEmbedder(nn.Module):
|
||||
"""Linear projection used both for patchified-image tokens and text-feature tokens."""
|
||||
def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.proj(x))
|
||||
|
||||
|
||||
class PixelTokenEmbedder(nn.Module):
|
||||
"""Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""
|
||||
def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_size_output = hidden_size_output
|
||||
self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, inputs, patch_size):
|
||||
B, _, H, W = inputs.shape
|
||||
Hs, Ws = H // patch_size, W // patch_size
|
||||
P2 = patch_size * patch_size
|
||||
x = inputs.permute(0, 2, 3, 1).contiguous()
|
||||
x = self.proj(x)
|
||||
pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output)
|
||||
x = x + pos_full.unsqueeze(0)
|
||||
x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
|
||||
return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output)
|
||||
|
||||
|
||||
class PiTBlock(nn.Module):
|
||||
"""Pixel-level transformer block.
|
||||
|
||||
Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
|
||||
runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
|
||||
Conditioning is per-pixel adaLN from the patch-level features.
|
||||
"""
|
||||
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
|
||||
attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
|
||||
super().__init__()
|
||||
self.pixel_dim = pixel_hidden_size
|
||||
self.context_dim = patch_hidden_size
|
||||
self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
|
||||
self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
|
||||
assert self.attn_dim % self.num_heads == 0
|
||||
|
||||
p2 = patch_size * patch_size
|
||||
self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
|
||||
self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations)
|
||||
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self._rope_fn = precompute_freqs_cis_2d
|
||||
self.mlp_chunks = max(1, int(mlp_chunks))
|
||||
|
||||
def _fetch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts)
|
||||
|
||||
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
|
||||
BL, P2, _ = x.shape
|
||||
Hs, Ws = image_height // patch_size, image_width // patch_size
|
||||
L = Hs * Ws
|
||||
B = BL // L
|
||||
|
||||
# Attention path uses only msa params; compute, use, free before mlp params allocate.
|
||||
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
|
||||
|
||||
x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
|
||||
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
|
||||
|
||||
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
|
||||
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
|
||||
attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
|
||||
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
|
||||
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
|
||||
x = torch.addcmul(x, gate_msa, attn_exp)
|
||||
del msa_params, shift_msa, scale_msa, gate_msa
|
||||
|
||||
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
|
||||
gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
|
||||
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
|
||||
del mlp_params, shift_mlp, scale_mlp
|
||||
|
||||
# MLP in chunks since the peak memory usage is huge here
|
||||
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
|
||||
for s in range(0, BL, chunk_size):
|
||||
e = min(s + chunk_size, BL)
|
||||
x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
|
||||
return x
|
||||
@ -1,226 +0,0 @@
|
||||
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
|
||||
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
|
||||
body + LQ projection branch injected before each MMDiT patch block.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .model import PixDiT_T2I
|
||||
from .modules import precompute_freqs_cis_2d
|
||||
|
||||
|
||||
class SigmaAwareGatePerTokenPerDim(nn.Module):
|
||||
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
|
||||
|
||||
Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
|
||||
self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
||||
content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
|
||||
# log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
|
||||
log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
|
||||
sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
|
||||
gate = torch.sigmoid(content_logit + sigma_offset)
|
||||
return x + (gate * lq).to(x.dtype)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""
|
||||
|
||||
def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
class LQProjection2D(nn.Module):
|
||||
"""LQ latent -> per-block patch-aligned features for controlnet-style injection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
latent_channels: int,
|
||||
hidden_dim: int = 512,
|
||||
out_dim: int = 1536,
|
||||
patch_size: int = 16,
|
||||
sr_scale: int = 4,
|
||||
latent_spatial_down_factor: int = 8,
|
||||
num_res_blocks: int = 4,
|
||||
num_outputs: int = 7,
|
||||
interval: int = 2,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
self.hidden_dim = hidden_dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.sr_scale = sr_scale
|
||||
self.latent_spatial_down_factor = latent_spatial_down_factor
|
||||
self.num_outputs = num_outputs
|
||||
self.interval = interval
|
||||
|
||||
z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
|
||||
self.z_to_patch_ratio = z_to_patch_ratio
|
||||
if z_to_patch_ratio >= 1:
|
||||
self.latent_fold_factor = 0
|
||||
latent_proj_in_ch = latent_channels
|
||||
else:
|
||||
fold_factor = int(1 / z_to_patch_ratio)
|
||||
assert fold_factor * z_to_patch_ratio == 1.0
|
||||
self.latent_fold_factor = fold_factor
|
||||
latent_proj_in_ch = latent_channels * fold_factor * fold_factor
|
||||
|
||||
layers = [
|
||||
operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
]
|
||||
for _ in range(num_res_blocks):
|
||||
layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
|
||||
self.latent_proj = nn.Sequential(*layers)
|
||||
|
||||
self.output_heads = nn.ModuleList(
|
||||
[operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
|
||||
)
|
||||
self.gate_modules = nn.ModuleList(
|
||||
[SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_outputs)]
|
||||
)
|
||||
|
||||
def is_gate_active(self, block_idx: int) -> bool:
|
||||
return block_idx % self.interval == 0
|
||||
|
||||
def output_index(self, block_idx: int) -> int:
|
||||
return block_idx // self.interval
|
||||
|
||||
def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
|
||||
return self.gate_modules[out_idx](x, lq_feature, sigma)
|
||||
|
||||
def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
|
||||
B, z_dim = lq_latent.shape[:2]
|
||||
if self.z_to_patch_ratio >= 1:
|
||||
if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
|
||||
z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
|
||||
else:
|
||||
z_aligned = lq_latent
|
||||
else:
|
||||
f = self.latent_fold_factor
|
||||
zH_expected, zW_expected = pH * f, pW * f
|
||||
if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
|
||||
lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
|
||||
z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
|
||||
z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
|
||||
return self.latent_proj(z_aligned)
|
||||
|
||||
def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
|
||||
feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
|
||||
B, C, H, W = feat.shape
|
||||
tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
|
||||
return [head(tokens) for head in self.output_heads]
|
||||
|
||||
|
||||
class PidNet(PixDiT_T2I):
|
||||
"""PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lq_latent_channels: int = 16,
|
||||
lq_hidden_dim: int = 512,
|
||||
lq_num_res_blocks: int = 4,
|
||||
lq_interval: int = 2,
|
||||
sr_scale: int = 4,
|
||||
latent_spatial_down_factor: int = 8,
|
||||
rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
|
||||
rope_ref_w: int = 1024,
|
||||
image_model=None,
|
||||
dtype=None, device=None, operations=None,
|
||||
**pixdit_kwargs,
|
||||
):
|
||||
super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
|
||||
|
||||
self.rope_ref_grid_h = rope_ref_h // self.patch_size
|
||||
self.rope_ref_grid_w = rope_ref_w // self.patch_size
|
||||
|
||||
# Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
|
||||
def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
|
||||
return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)
|
||||
for blk in self.pixel_blocks:
|
||||
blk._rope_fn = _pit_rope_fn
|
||||
|
||||
num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
|
||||
self.lq_proj = LQProjection2D(
|
||||
latent_channels=lq_latent_channels,
|
||||
hidden_dim=lq_hidden_dim,
|
||||
out_dim=self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
sr_scale=sr_scale,
|
||||
latent_spatial_down_factor=latent_spatial_down_factor,
|
||||
num_res_blocks=lq_num_res_blocks,
|
||||
num_outputs=num_lq_outputs,
|
||||
interval=lq_interval,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return precompute_freqs_cis_2d(
|
||||
self.hidden_size // self.num_groups,
|
||||
height, width,
|
||||
ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
|
||||
device=device, dtype=dtype, **rope_opts,
|
||||
)
|
||||
|
||||
def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
|
||||
if not self.lq_proj.is_gate_active(i):
|
||||
return s
|
||||
out_idx = self.lq_proj.output_index(i)
|
||||
if out_idx >= len(pid_lq_features):
|
||||
return s
|
||||
return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
|
||||
if lq_latent is None:
|
||||
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
|
||||
expected_c = self.lq_proj.latent_channels
|
||||
if lq_latent.shape[1] != expected_c:
|
||||
raise ValueError(
|
||||
f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
|
||||
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
|
||||
)
|
||||
B = x.shape[0]
|
||||
Hs = x.shape[2] // self.patch_size
|
||||
Ws = x.shape[3] // self.patch_size
|
||||
|
||||
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
|
||||
if degrade_sigma.numel() == 1 and B > 1:
|
||||
degrade_sigma = degrade_sigma.expand(B).contiguous()
|
||||
|
||||
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
|
||||
|
||||
return super()._forward(
|
||||
x, timesteps,
|
||||
context=context, attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
pid_lq_features=lq_features,
|
||||
pid_degrade_sigma=degrade_sigma,
|
||||
**kwargs,
|
||||
)
|
||||
@ -49,8 +49,6 @@ import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.pixeldit.model
|
||||
import comfy.ldm.pixeldit.pid
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
@ -63,6 +61,7 @@ import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
import comfy.ldm.hidream_o1.model
|
||||
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
|
||||
import comfy.ldm.depth_anything_3.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1399,36 +1398,6 @@ class ZImagePixelSpace(Lumina2):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
|
||||
class PixelDiTT2I(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
|
||||
return out
|
||||
|
||||
|
||||
class PiD(PixelDiTT2I):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.pixeldit.pid.PidNet)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
lq_latent = kwargs.get("lq_latent", None)
|
||||
if lq_latent is not None:
|
||||
out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
|
||||
degrade_sigma = kwargs.get("degrade_sigma", None)
|
||||
if degrade_sigma is not None:
|
||||
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
|
||||
return out
|
||||
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
@ -2175,6 +2144,12 @@ class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
|
||||
class DepthAnything3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net)
|
||||
|
||||
class ErnieImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
|
||||
|
||||
@ -463,23 +463,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
|
||||
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
|
||||
if _lq_w_key in state_dict_keys:
|
||||
in_ch = int(state_dict[_lq_w_key].shape[1])
|
||||
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
|
||||
num_gates = len({k[len(_gate_prefix):].split('.')[0]
|
||||
for k in state_dict_keys if k.startswith(_gate_prefix)})
|
||||
dit_config = {"image_model": "pid",
|
||||
"lq_latent_channels": in_ch,
|
||||
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
|
||||
if num_gates > 0:
|
||||
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
|
||||
return dit_config
|
||||
|
||||
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
|
||||
return {"image_model": "pixeldit_t2i"}
|
||||
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "lumina2"
|
||||
@ -846,6 +829,108 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
# Depth Anything 3
|
||||
if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "DepthAnything3"
|
||||
|
||||
patch_w = state_dict['{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix)]
|
||||
embed_dim = patch_w.shape[0]
|
||||
depth = count_blocks(state_dict_keys, '{}backbone.pretrained.blocks.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
|
||||
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
|
||||
if backbone_name is None:
|
||||
return None
|
||||
dit_config["backbone_name"] = backbone_name
|
||||
|
||||
# Detect DA3 extensions on top of vanilla DINOv2.
|
||||
has_camera_token = '{}backbone.pretrained.camera_token'.format(key_prefix) in state_dict_keys
|
||||
# qk-norm shows up as `attn.q_norm.weight` on enabled blocks.
|
||||
qknorm_indices = [
|
||||
i for i in range(depth)
|
||||
if '{}backbone.pretrained.blocks.{}.attn.q_norm.weight'.format(key_prefix, i) in state_dict_keys
|
||||
]
|
||||
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
|
||||
|
||||
# The DA3 main-series configs always set alt_start == qknorm_start == rope_start.
|
||||
# cat_token=True is implied by the presence of camera_token.
|
||||
if has_camera_token:
|
||||
dit_config["alt_start"] = qknorm_start
|
||||
dit_config["rope_start"] = qknorm_start
|
||||
dit_config["qknorm_start"] = qknorm_start
|
||||
dit_config["cat_token"] = True
|
||||
else:
|
||||
dit_config["alt_start"] = -1
|
||||
dit_config["rope_start"] = -1
|
||||
dit_config["qknorm_start"] = -1
|
||||
dit_config["cat_token"] = False
|
||||
|
||||
# Detect head type and config.
|
||||
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
|
||||
if has_aux:
|
||||
dit_config["head_type"] = "dualdpt"
|
||||
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
|
||||
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
out_channels = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["head_dim_in"] = head_dim_in
|
||||
dit_config["head_output_dim"] = 2
|
||||
dit_config["head_features"] = features
|
||||
dit_config["head_out_channels"] = out_channels
|
||||
dit_config["head_use_sky_head"] = False
|
||||
else:
|
||||
dit_config["head_type"] = "dpt"
|
||||
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
out_channels = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
output_dim = state_dict[
|
||||
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
|
||||
].shape[0]
|
||||
dit_config["head_dim_in"] = head_dim_in
|
||||
dit_config["head_output_dim"] = output_dim
|
||||
dit_config["head_features"] = features
|
||||
dit_config["head_out_channels"] = out_channels
|
||||
dit_config["head_use_sky_head"] = (
|
||||
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
|
||||
)
|
||||
|
||||
# out_layers: hard-coded per upstream YAML config (depth-aware default).
|
||||
if depth >= 24:
|
||||
# vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT).
|
||||
if has_aux:
|
||||
dit_config["out_layers"] = [11, 15, 19, 23]
|
||||
else:
|
||||
dit_config["out_layers"] = [4, 11, 17, 23]
|
||||
else:
|
||||
# vits/vitb: 12 blocks
|
||||
dit_config["out_layers"] = [5, 7, 9, 11]
|
||||
|
||||
# Camera encoder/decoder presence (multi-view + pose path).
|
||||
has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys
|
||||
has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["has_cam_enc"] = has_cam_enc
|
||||
dit_config["has_cam_dec"] = has_cam_dec
|
||||
if has_cam_enc:
|
||||
cam_enc_w = state_dict.get(
|
||||
'{}cam_enc.pose_branch.fc2.weight'.format(key_prefix)
|
||||
)
|
||||
if cam_enc_w is not None:
|
||||
dit_config["cam_dim_out"] = cam_enc_w.shape[0]
|
||||
if has_cam_dec:
|
||||
cam_dec_w = state_dict.get(
|
||||
'{}cam_dec.fc_t.weight'.format(key_prefix)
|
||||
)
|
||||
if cam_dec_w is not None:
|
||||
dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1]
|
||||
return dit_config
|
||||
|
||||
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ernie"
|
||||
|
||||
10
comfy/sd.py
10
comfy/sd.py
@ -49,7 +49,6 @@ import comfy.text_encoders.lt
|
||||
import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.pixeldit
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
@ -1286,7 +1285,6 @@ class CLIPType(Enum):
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
LENS = 28
|
||||
PIXELDIT = 29
|
||||
|
||||
|
||||
|
||||
@ -1530,12 +1528,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
if clip_type == CLIPType.PIXELDIT:
|
||||
clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_4B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||
|
||||
@ -30,7 +30,6 @@ import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.hidream_o1
|
||||
import comfy.text_encoders.pixeldit
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -845,8 +844,6 @@ class Lens(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
memory_usage_factor = 4.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
@ -1204,72 +1201,6 @@ class ZImagePixelSpace(ZImage):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ZImagePixelSpace(self, device=device)
|
||||
|
||||
class PixelDiTT2I(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "pixeldit_t2i",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
|
||||
}
|
||||
|
||||
latent_format = latent_formats.PixelDiTPixel
|
||||
memory_usage_factor = 0.04
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PixelDiTT2I(self, device=device)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
|
||||
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
|
||||
|
||||
out = {}
|
||||
marker = ".adaLN_modulation.0."
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
||||
continue
|
||||
if k.startswith("core."):
|
||||
k = k[len("core."):]
|
||||
elif k.startswith("net."):
|
||||
k = k[len("net."):]
|
||||
if "pixel_blocks." in k and marker in k:
|
||||
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
|
||||
p2 = v.shape[0] // (6 * pixel_dim)
|
||||
trail = v.shape[1:] # () for bias, (in_dim,) for weight
|
||||
vv = v.view(p2, 6, pixel_dim, *trail)
|
||||
base, suffix = k.split(marker)
|
||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
|
||||
comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
|
||||
)
|
||||
|
||||
class PiD(PixelDiTT2I):
|
||||
unet_config = {
|
||||
"image_model": "pid",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.04
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PiD(self, device=device)
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1982,6 +1913,101 @@ class RT_DETR_v4(supported_models_base.BASE):
|
||||
return None
|
||||
|
||||
|
||||
class DepthAnything3(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "DepthAnything3",
|
||||
}
|
||||
|
||||
# Mono path: no num_heads / num_head_channels needed.
|
||||
unet_extra_config = {}
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.DepthAnything3(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
# Drop Gaussian-head weights; remap fused backbone QKV to Dinov2Model layout.
|
||||
drop_prefixes = ("gs_head.", "gs_adapter.")
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(drop_prefixes):
|
||||
state_dict.pop(k)
|
||||
return _da3_remap_backbone_keys(state_dict, prefix="backbone.")
|
||||
|
||||
|
||||
def _da3_remap_backbone_keys(state_dict, prefix="backbone."):
|
||||
"""Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` under ``prefix``."""
|
||||
pre = prefix + "pretrained."
|
||||
src_keys = [k for k in state_dict.keys() if k.startswith(pre)]
|
||||
if not src_keys:
|
||||
return state_dict
|
||||
|
||||
static_renames = {
|
||||
pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight",
|
||||
pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias",
|
||||
pre + "pos_embed": prefix + "embeddings.position_embeddings",
|
||||
pre + "cls_token": prefix + "embeddings.cls_token",
|
||||
pre + "camera_token": prefix + "embeddings.camera_token",
|
||||
pre + "norm.weight": prefix + "layernorm.weight",
|
||||
pre + "norm.bias": prefix + "layernorm.bias",
|
||||
}
|
||||
for src, dst in static_renames.items():
|
||||
if src in state_dict:
|
||||
state_dict[dst] = state_dict.pop(src)
|
||||
|
||||
block_pre = pre + "blocks."
|
||||
block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)]
|
||||
for k in block_keys:
|
||||
rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight"
|
||||
idx_str, _, sub = rest.partition(".")
|
||||
target_block = "{}encoder.layer.{}.".format(prefix, idx_str)
|
||||
|
||||
# Fused QKV -> split query/key/value linears.
|
||||
if sub == "attn.qkv.weight":
|
||||
qkv = state_dict.pop(k)
|
||||
c = qkv.shape[0] // 3
|
||||
state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone()
|
||||
state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone()
|
||||
state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone()
|
||||
continue
|
||||
if sub == "attn.qkv.bias":
|
||||
qkv = state_dict.pop(k)
|
||||
c = qkv.shape[0] // 3
|
||||
state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone()
|
||||
state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone()
|
||||
state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone()
|
||||
continue
|
||||
|
||||
# Sub-key remap (suffix preserved).
|
||||
if sub.startswith("attn.proj."):
|
||||
tail = sub[len("attn.proj."):]
|
||||
new = "attention.output.dense." + tail
|
||||
elif sub.startswith("attn.q_norm."):
|
||||
new = "attention.q_norm." + sub[len("attn.q_norm."):]
|
||||
elif sub.startswith("attn.k_norm."):
|
||||
new = "attention.k_norm." + sub[len("attn.k_norm."):]
|
||||
elif sub == "ls1.gamma":
|
||||
new = "layer_scale1.lambda1"
|
||||
elif sub == "ls2.gamma":
|
||||
new = "layer_scale2.lambda1"
|
||||
elif sub.startswith("mlp.w12."):
|
||||
new = "mlp.weights_in." + sub[len("mlp.w12."):]
|
||||
elif sub.startswith("mlp.w3."):
|
||||
new = "mlp.weights_out." + sub[len("mlp.w3."):]
|
||||
elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")):
|
||||
new = sub
|
||||
else:
|
||||
# Unrecognised key -- leave as-is so load_state_dict can complain.
|
||||
continue
|
||||
|
||||
state_dict[target_block + new] = state_dict.pop(k)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
class ErnieImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ernie",
|
||||
@ -2180,8 +2206,6 @@ models = [
|
||||
CosmosI2VPredict2,
|
||||
ZImagePixelSpace,
|
||||
ZImage,
|
||||
PiD,
|
||||
PixelDiTT2I,
|
||||
Lumina2,
|
||||
WAN22_T2V,
|
||||
WAN21_CausalAR_T2V,
|
||||
@ -2221,4 +2245,5 @@ models = [
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
DepthAnything3,
|
||||
]
|
||||
|
||||
@ -1,104 +0,0 @@
|
||||
import torch
|
||||
|
||||
from comfy import sd1_clip
|
||||
from .lumina2 import Gemma2BTokenizer, LuminaModel
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
|
||||
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(
|
||||
device=device, layer=layer, layer_idx=layer_idx,
|
||||
textmodel_json_config={}, dtype=dtype,
|
||||
special_tokens={"start": 2, "pad": 0},
|
||||
layer_norm_hidden_state=False,
|
||||
model_class=comfy.text_encoders.llama.Gemma2_2B,
|
||||
enable_attention_masks=attention_mask,
|
||||
return_attention_masks=attention_mask,
|
||||
model_options=model_options,
|
||||
)
|
||||
|
||||
|
||||
_PIXELDIT_CHI_PROMPT = (
|
||||
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
|
||||
"suitable for image generation. Evaluate the level of detail in the user prompt:\n"
|
||||
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
|
||||
"and spatial relationships to create vivid and concrete scenes.\n"
|
||||
"- If the prompt is already detailed, refine and enhance the existing details slightly without "
|
||||
"overcomplicating.\n"
|
||||
"Here are examples of how to transform or refine prompts:\n"
|
||||
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
|
||||
"sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
|
||||
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
|
||||
"glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
|
||||
"passing by towering glass skyscrapers.\n"
|
||||
"Please generate only the enhanced description for the prompt below and avoid including any "
|
||||
"additional commentary or evaluations:\n"
|
||||
"User Prompt: "
|
||||
)
|
||||
|
||||
_PIXELDIT_MAX_LENGTH = 300
|
||||
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
|
||||
|
||||
|
||||
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = {}
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
name="gemma2_2b", tokenizer=Gemma2BTokenizer)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
if not text.strip():
|
||||
return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH)
|
||||
|
||||
chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
|
||||
combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
|
||||
max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
|
||||
out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids,
|
||||
disable_weights=True, min_length=max_length_all)
|
||||
out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]]
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.gemma2_2b.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return self.gemma2_2b.state_dict()
|
||||
|
||||
|
||||
class PixelDiTGemma2TE(LuminaModel):
|
||||
# PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="gemma2_2b",
|
||||
clip_model=PixelDiTGemma2_2BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
result = super().encode_token_weights(token_weight_pairs)
|
||||
cond, pooled = result[0], result[1]
|
||||
extra = result[2] if len(result) > 2 else None
|
||||
if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
|
||||
cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1)
|
||||
if extra is not None and "attention_mask" in extra:
|
||||
am = extra["attention_mask"]
|
||||
extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1)
|
||||
if extra is not None:
|
||||
return cond, pooled, extra
|
||||
return cond, pooled
|
||||
|
||||
|
||||
def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class PixelDiTTE_(PixelDiTGemma2TE):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixelDiTTE_
|
||||
@ -158,9 +158,8 @@ class SeedanceCreateAssetResponse(BaseModel):
|
||||
|
||||
|
||||
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
|
||||
url: str = Field(..., description="Publicly accessible URL of the asset to upload.")
|
||||
url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
|
||||
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
|
||||
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
"""Pydantic models for the Krea image-generation API."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class KreaMoodboard(BaseModel):
|
||||
id: str = Field(...)
|
||||
strength: float = Field(default=0.35, ge=-0.5, le=1.5)
|
||||
|
||||
|
||||
class KreaImageStyleReference(BaseModel):
|
||||
strength: float = Field(..., ge=-2.0, le=2.0)
|
||||
url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaGenerateImageRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
aspect_ratio: str = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int | None = Field(default=None)
|
||||
creativity: str = Field(default="medium")
|
||||
moodboards: list[KreaMoodboard] | None = Field(default=None)
|
||||
image_style_references: list[KreaImageStyleReference] | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaJobResult(BaseModel):
|
||||
urls: list[str] | None = Field(default=None)
|
||||
style_id: str | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaJob(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
completed_at: str | None = Field(default=None)
|
||||
result: KreaJobResult | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaAssetResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
uploaded_at: str = Field(...)
|
||||
width: float | None = Field(default=None)
|
||||
height: float | None = Field(default=None)
|
||||
size_bytes: float | None = Field(default=None)
|
||||
mime_type: str | None = Field(default=None)
|
||||
@ -2,12 +2,11 @@ import hashlib
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
@ -309,26 +308,6 @@ async def _seedance_virtual_library_upload_image_asset(
|
||||
return f"asset://{create_resp.asset_id}"
|
||||
|
||||
|
||||
async def _seedance_virtual_library_upload_video_asset(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: Input.Video,
|
||||
*,
|
||||
wait_label: str = "Uploading video",
|
||||
) -> str:
|
||||
buf = BytesIO()
|
||||
video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264)
|
||||
video_hash = hashlib.sha256(buf.getbuffer()).hexdigest()
|
||||
public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label)
|
||||
create_resp = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"),
|
||||
response_model=SeedanceCreateAssetResponse,
|
||||
data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"),
|
||||
)
|
||||
await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library")
|
||||
return f"asset://{create_resp.asset_id}"
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
@ -2127,7 +2106,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
content.append(
|
||||
TaskVideoContent(
|
||||
video_url=TaskVideoContentUrl(
|
||||
url=await _seedance_virtual_library_upload_video_asset(
|
||||
url=await upload_video_to_comfyapi(
|
||||
cls,
|
||||
reference_videos[key],
|
||||
wait_label=f"Uploading video {i}",
|
||||
|
||||
@ -1,290 +0,0 @@
|
||||
"""Krea image-generation nodes."""
|
||||
|
||||
import re
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.krea import (
|
||||
KreaAssetResponse,
|
||||
KreaGenerateImageRequest,
|
||||
KreaImageStyleReference,
|
||||
KreaJob,
|
||||
KreaMoodboard,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
class KreaIO:
|
||||
STYLE_REF = "KREA_STYLE_REF"
|
||||
|
||||
|
||||
async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str:
|
||||
"""Upload an image to Krea's /assets endpoint and return the Krea-hosted image URL."""
|
||||
img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"),
|
||||
response_model=KreaAssetResponse,
|
||||
files=[("file", (img_io.name, img_io, "image/png"))],
|
||||
content_type="multipart/form-data",
|
||||
max_retries=1,
|
||||
wait_label="Uploading reference",
|
||||
)
|
||||
return response.image_url
|
||||
|
||||
|
||||
_MODEL_MEDIUM = "Krea 2 Medium"
|
||||
_MODEL_LARGE = "Krea 2 Large"
|
||||
_MODEL_ENDPOINTS: dict[str, str] = {
|
||||
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
|
||||
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
|
||||
}
|
||||
|
||||
_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"]
|
||||
_RESOLUTIONS = ["1K"]
|
||||
_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"]
|
||||
_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"]
|
||||
|
||||
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
|
||||
|
||||
|
||||
def _krea_model_inputs() -> list:
|
||||
"""Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=_ASPECT_RATIOS,
|
||||
tooltip="Output aspect ratio.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=_RESOLUTIONS,
|
||||
tooltip="Resolution scale.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"creativity",
|
||||
options=_CREATIVITY_LEVELS,
|
||||
default="medium",
|
||||
tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"moodboard_id",
|
||||
default="",
|
||||
tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). "
|
||||
"Leave empty to disable. Only one moodboard is supported per request.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"moodboard_strength",
|
||||
default=0.35,
|
||||
min=-0.5,
|
||||
max=1.5,
|
||||
step=0.05,
|
||||
tooltip="Moodboard influence; ignored when moodboard_id is empty.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Custom(KreaIO.STYLE_REF).Input(
|
||||
"style_reference",
|
||||
optional=True,
|
||||
tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class Krea2ImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Krea2ImageNode",
|
||||
display_name="Krea 2 Image",
|
||||
category="api node/image/Krea",
|
||||
description=(
|
||||
"Generate images via Krea 2 — pick Medium (expressive illustrations) or "
|
||||
"Large (expressive photorealism). Supports an optional moodboard and up "
|
||||
"to 10 chained image style references."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the image.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
|
||||
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
|
||||
],
|
||||
tooltip="Krea 2 Medium is best for expressive illustrations; "
|
||||
"Krea 2 Large is best for expressive photorealism.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Random seed for reproducibility.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "model.moodboard_id"],
|
||||
inputs=["model.style_reference"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isLarge := widgets.model = "krea 2 large";
|
||||
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
|
||||
$hasStyle := $lookup(inputs, "model.style_reference").connected;
|
||||
$usd := $hasMoodboard
|
||||
? ($isLarge ? 0.07 : 0.04)
|
||||
: ($hasStyle
|
||||
? ($isLarge ? 0.065 : 0.035)
|
||||
: ($isLarge ? 0.06 : 0.03));
|
||||
{"type":"usd","usd": $usd}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
|
||||
model_choice = model["model"]
|
||||
endpoint_path = _MODEL_ENDPOINTS.get(model_choice)
|
||||
if endpoint_path is None:
|
||||
raise ValueError(f"Unknown Krea 2 model: {model_choice!r}")
|
||||
|
||||
moodboards: list[KreaMoodboard] | None = None
|
||||
mb_id = (model.get("moodboard_id") or "").strip()
|
||||
if mb_id:
|
||||
if not _UUID_RE.match(mb_id):
|
||||
raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.")
|
||||
mb_strength = model.get("moodboard_strength")
|
||||
moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))]
|
||||
|
||||
style_reference = model.get("style_reference")
|
||||
image_style_references: list[KreaImageStyleReference] | None = None
|
||||
if style_reference:
|
||||
if len(style_reference) > 10:
|
||||
raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.")
|
||||
image_style_references = [
|
||||
KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference
|
||||
]
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=endpoint_path, method="POST"),
|
||||
response_model=KreaJob,
|
||||
data=KreaGenerateImageRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
resolution=model["resolution"],
|
||||
seed=seed,
|
||||
creativity=model["creativity"],
|
||||
moodboards=moodboards,
|
||||
image_style_references=image_style_references,
|
||||
),
|
||||
)
|
||||
job = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"),
|
||||
response_model=KreaJob,
|
||||
status_extractor=lambda r: r.status,
|
||||
queued_statuses=_KREA_QUEUED_STATUSES,
|
||||
)
|
||||
if not job.result or not job.result.urls:
|
||||
raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.")
|
||||
image = await download_url_to_image_tensor(job.result.urls[0])
|
||||
return IO.NodeOutput(image)
|
||||
|
||||
|
||||
class Krea2StyleReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Krea2StyleReferenceNode",
|
||||
display_name="Krea 2 Style Reference",
|
||||
category="api node/image/Krea",
|
||||
description=(
|
||||
"Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
|
||||
"Style Reference nodes (max 10) and feed the final `style_reference` output "
|
||||
"into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL."
|
||||
),
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Reference image whose style influences the generation.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=1.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.05,
|
||||
tooltip="Reference strength; negative values invert the style influence.",
|
||||
),
|
||||
IO.Custom(KreaIO.STYLE_REF).Input(
|
||||
"style_reference",
|
||||
optional=True,
|
||||
tooltip="Optional incoming chain of style references; this node appends one more.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
strength: float,
|
||||
style_reference: list[dict] | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain: list[dict] = list(style_reference) if style_reference else []
|
||||
if len(chain) >= 10:
|
||||
raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.")
|
||||
url = await _upload_image_to_krea_assets(cls, image)
|
||||
chain.append({"url": url, "strength": float(strength)})
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class KreaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Krea2ImageNode,
|
||||
Krea2StyleReferenceNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> KreaExtension:
|
||||
return KreaExtension()
|
||||
747
comfy_extras/nodes_depth_anything_3.py
Normal file
747
comfy_extras/nodes_depth_anything_3.py
Normal file
@ -0,0 +1,747 @@
|
||||
"""ComfyUI nodes for Depth Anything 3.
|
||||
|
||||
Adds these nodes:
|
||||
|
||||
* ``LoadDA3Model`` -- load a DA3 ``.safetensors`` file from the
|
||||
``models/geometry_estimation/`` folder.
|
||||
* ``DA3Inference`` -- unified depth estimation node supporting both mono and
|
||||
multi-view modes via a DynamicCombo selector. Returns a DA3_GEOMETRY dict of
|
||||
raw tensors (depth, sky, confidence, camera). Feed into ``DA3Render``
|
||||
to produce display images, or directly into ``MoGeRender`` for depth / mask views.
|
||||
* ``DA3Render`` -- post-processes a DA3_GEOMETRY dict: applies optional
|
||||
sky clipping, normalises depth and confidence, and returns display images.
|
||||
|
||||
Model capability matrix
|
||||
-----------------------
|
||||
Variant head_type has_sky has_conf cam_dec
|
||||
DA3-Small dualdpt False True yes
|
||||
DA3-Base dualdpt False True yes
|
||||
DA3-Mono-Large dpt True False no
|
||||
DA3-Metric-Large dpt True False no (raw output is metres)
|
||||
|
||||
The node raises a ``ValueError`` at execution time when the selected
|
||||
parameters conflict with the loaded model's capabilities (e.g.
|
||||
``apply_sky_clip=True`` on a model with no sky head).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management as mm
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
from comfy.ldm.colormap import turbo as _turbo
|
||||
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
|
||||
from comfy_api.latest import ComfyExtension, Types, io
|
||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||
|
||||
DA3ModelType = io.Custom("DA3_MODEL")
|
||||
DA3Geometry = io.Custom("DA3_GEOMETRY")
|
||||
DA3PointCloud = io.Custom("DA3_POINT_CLOUD")
|
||||
|
||||
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
|
||||
#
|
||||
# Per-frame tensors — B = batch size in mono mode; B = S (number of views) in multi-view mode.
|
||||
# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention)
|
||||
# "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present)
|
||||
# "mode": str -- "mono" or "multiview" (always present)
|
||||
# "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only)
|
||||
# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only)
|
||||
#
|
||||
# Multi-view only — S = number of views; the leading 1 is the scene dimension from the model.
|
||||
# "extrinsics": torch.Tensor (1, S, 3, 4) -- world-to-camera [R|t] matrices
|
||||
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
|
||||
#
|
||||
# DA3_POINT_CLOUD is a dict:
|
||||
# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back)
|
||||
# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None
|
||||
# "confidence": torch.Tensor (N,) -- raw confidence per point, or None
|
||||
|
||||
|
||||
def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
|
||||
"""Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space."""
|
||||
H, W = depth.shape
|
||||
u = torch.arange(W, dtype=torch.float32, device=depth.device)
|
||||
v = torch.arange(H, dtype=torch.float32, device=depth.device)
|
||||
u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W)
|
||||
pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3)
|
||||
rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix)
|
||||
return rays * depth.unsqueeze(-1) # (H, W, 3)
|
||||
|
||||
|
||||
def _da3_default_K(H: int, W: int) -> torch.Tensor:
|
||||
"""Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry)."""
|
||||
fx = fy = float(W) * 0.7
|
||||
return torch.tensor([[fx, 0.0, (W - 1) / 2.0],
|
||||
[0.0, fy, (H - 1) / 2.0],
|
||||
[0.0, 0.0, 1.0]], dtype=torch.float32)
|
||||
|
||||
|
||||
def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
|
||||
"""Return pixel-space K for batch element b, falling back to a default estimate."""
|
||||
if "intrinsics" in geometry:
|
||||
# shape (1, S, 3, 3) — leading scene dimension from the multiview head
|
||||
return geometry["intrinsics"][0, b].float()
|
||||
logging.getLogger("comfy").warning(
|
||||
"DA3_GEOMETRY has no intrinsics (mono-mode model). "
|
||||
"Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate."
|
||||
)
|
||||
return _da3_default_K(H, W)
|
||||
|
||||
|
||||
def _da3_get_extrinsic(geometry: dict, b: int) -> torch.Tensor | None:
|
||||
"""Return the world-to-camera extrinsic for batch element b, or None in mono mode.
|
||||
|
||||
The model outputs (1, S, 3, 4) [R|t] matrices; the fallback identity is (4, 4).
|
||||
_da3_apply_extrinsic handles both shapes via [:3, :3] / [:3, 3] slicing.
|
||||
"""
|
||||
if "extrinsics" not in geometry:
|
||||
return None
|
||||
return geometry["extrinsics"][0, b].float()
|
||||
|
||||
|
||||
def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
|
||||
"""Transform (H,W,3) OpenCV camera-space points to world space.
|
||||
|
||||
E is the world-to-camera SE(3) matrix (3×4 or 4×4). The camera-to-world
|
||||
inverse is computed analytically as [Rᵀ | −Rᵀt] rather than via
|
||||
torch.linalg.inv to avoid numerical failures on near-degenerate poses.
|
||||
|
||||
Returns the original camera-space points unchanged if E contains non-finite
|
||||
values (failed pose estimation), so the node can still produce a mesh.
|
||||
"""
|
||||
E = E.to(points_cam.device).float()
|
||||
if not torch.isfinite(E).all():
|
||||
logging.getLogger("comfy").warning(
|
||||
"DA3 extrinsic matrix contains non-finite values (pose estimation may have failed). "
|
||||
"Falling back to camera-space coordinates."
|
||||
)
|
||||
return points_cam
|
||||
H, W, _ = points_cam.shape
|
||||
R = E[:3, :3] # (3, 3) rotation
|
||||
t = E[:3, 3] # (3,) translation
|
||||
R_inv = R.T # rotation inverse = transpose for orthogonal R
|
||||
t_inv = -(R_inv @ t) # (3,)
|
||||
pts = points_cam.reshape(-1, 3) # (N, 3)
|
||||
pts_world = pts @ R_inv.T + t_inv # (N, 3)
|
||||
return pts_world.reshape(H, W, 3)
|
||||
|
||||
|
||||
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
|
||||
"""Map raw confidence (exp(x)+1 activation, range [1, ∞)) to [0, 1] per image.
|
||||
|
||||
Min-max per image preserves the spatial pattern while producing a [0, 1]
|
||||
value suitable for both display and masking.
|
||||
"""
|
||||
B = conf.shape[0]
|
||||
out = []
|
||||
for i in range(B):
|
||||
c = conf[i]
|
||||
c_min, c_max = c.min(), c.max()
|
||||
out.append((c - c_min) / (c_max - c_min) if c_max > c_min else torch.ones_like(c))
|
||||
return torch.stack(out, dim=0)
|
||||
|
||||
|
||||
def _da3_build_mask(geometry: dict, b: int, H: int, W: int,
|
||||
confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
|
||||
"""Build (H,W) bool keep-mask from sky probability and confidence."""
|
||||
mask = torch.ones(H, W, dtype=torch.bool)
|
||||
if use_sky_mask and "sky" in geometry:
|
||||
mask = mask & (geometry["sky"][b] < 0.5)
|
||||
if "confidence" in geometry and confidence_threshold > 0.0:
|
||||
conf_norm = _normalize_confidence(geometry["confidence"][b:b + 1])[0]
|
||||
mask = mask & (conf_norm >= confidence_threshold)
|
||||
return mask
|
||||
|
||||
|
||||
class LoadDA3Model(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadDA3Model",
|
||||
display_name="Load Depth Anything 3",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"model_name",
|
||||
options=folder_paths.get_filename_list("geometry_estimation"),
|
||||
),
|
||||
io.Combo.Input(
|
||||
"weight_dtype",
|
||||
options=["default", "fp16", "bf16", "fp32"],
|
||||
default="default",
|
||||
),
|
||||
],
|
||||
outputs=[DA3ModelType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name, weight_dtype) -> io.NodeOutput:
|
||||
model_options = {}
|
||||
if weight_dtype == "fp16":
|
||||
model_options["dtype"] = torch.float16
|
||||
elif weight_dtype == "bf16":
|
||||
model_options["dtype"] = torch.bfloat16
|
||||
elif weight_dtype == "fp32":
|
||||
model_options["dtype"] = torch.float32
|
||||
|
||||
path = folder_paths.get_full_path_or_raise("geometry_estimation", model_name)
|
||||
model = comfy.sd.load_diffusion_model(path, model_options=model_options)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
|
||||
method: str = "upper_bound_resize"):
|
||||
"""Run DA3 on ``(B,H,W,3)`` IMAGE; returns depth/conf/sky at original resolution (or None)."""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
|
||||
B, H, W, _ = image.shape
|
||||
mm.load_model_gpu(model_patcher)
|
||||
diffusion = model_patcher.model.diffusion_model
|
||||
device = mm.get_torch_device()
|
||||
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
|
||||
|
||||
depths, confs, skies = [], [], []
|
||||
for i in range(B):
|
||||
single = image[i:i + 1].to(device)
|
||||
x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
|
||||
x = x.to(dtype=dtype)
|
||||
with torch.no_grad():
|
||||
out = diffusion(x)
|
||||
|
||||
depth_lr = out["depth"]
|
||||
depth_full = torch.nn.functional.interpolate(
|
||||
depth_lr.unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
depths.append(depth_full)
|
||||
|
||||
if "depth_conf" in out:
|
||||
conf_full = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
confs.append(conf_full)
|
||||
if "sky" in out:
|
||||
sky_full = torch.nn.functional.interpolate(
|
||||
out["sky"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
skies.append(sky_full)
|
||||
|
||||
depth = torch.cat(depths, dim=0)
|
||||
confidence = torch.cat(confs, dim=0) if confs else None
|
||||
sky = torch.cat(skies, dim=0) if skies else None
|
||||
return depth, confidence, sky
|
||||
|
||||
|
||||
class DA3Inference(io.ComfyNode):
|
||||
"""Raw Depth Anything 3 inference node.
|
||||
|
||||
Outputs a DA3_GEOMETRY dict of raw tensors. All display normalization
|
||||
(sky clipping, depth scaling, confidence normalisation) is handled by
|
||||
the companion ``DA3Render`` node.
|
||||
|
||||
Mono mode: each batch element is processed independently.
|
||||
Multi-view mode: all frames share a single forward pass with cross-view
|
||||
attention; adds ``extrinsics`` and ``intrinsics`` to the geometry dict.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3Inference",
|
||||
search_aliases=["depth", "geometry", "da3", "depth anything", "monocular", "pointmap", "sky", "3d", "metric depth", "disparity"],
|
||||
display_name="Run Depth Anything 3",
|
||||
category="image/geometry_estimation",
|
||||
description="Run Depth Anything 3 on an image or image batch. In multi-view mode each frame is treated as a separate view of the same scene.",
|
||||
inputs=[
|
||||
DA3ModelType.Input("da3_model"),
|
||||
io.Image.Input("image",
|
||||
tooltip="Single image or image batch. "
|
||||
"In multi-view mode each frame is treated as "
|
||||
"a separate view of the same scene."),
|
||||
io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
|
||||
tooltip="Resolution the model runs at (longest side, multiple of 14). "
|
||||
"Lower = faster / less VRAM; higher = more detail. "
|
||||
"Output is upsampled back to the original size."),
|
||||
io.Combo.Input("resize_method",
|
||||
options=["upper_bound_resize", "lower_bound_resize"],
|
||||
default="upper_bound_resize",
|
||||
tooltip="- upper_bound_resize: scale so the longest side = process_res (caps memory, default).\n"
|
||||
"- lower_bound_resize: scale so the shortest side = process_res (preserves more detail on tall/wide images, uses more memory)."),
|
||||
io.DynamicCombo.Input("mode",
|
||||
tooltip="- mono: single image or independent batch — works with any model variant.\n"
|
||||
"- multiview: all frames processed together for geometric consistency + camera pose — requires DA3-Small or DA3-Base (DA3-Mono-Large / DA3-Metric-Large do NOT support this mode).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("mono", []),
|
||||
io.DynamicCombo.Option("multiview", [
|
||||
io.Combo.Input("ref_view_strategy",
|
||||
options=["saddle_balanced", "saddle_sim_range",
|
||||
"first", "middle"],
|
||||
default="saddle_balanced",
|
||||
tooltip="Which view to use as the geometric anchor (only applied when S >= 3 and no extrinsics are provided).\n"
|
||||
"- saddle_balanced: picks the view whose CLS-token features are closest to the median across similarity, norm and variance — best general choice.\n"
|
||||
"- saddle_sim_range: picks the view with the widest similarity spread to other views — favours the most distinct viewpoint.\n"
|
||||
"- first / middle: deterministic positional fallbacks."),
|
||||
io.Combo.Input("pose_method",
|
||||
options=["cam_dec", "ray_pose"],
|
||||
default="cam_dec",
|
||||
tooltip="- cam_dec: small MLP on the final camera token — works on DA3-Small and DA3-Base.\n"
|
||||
"- ray_pose: RANSAC over the DualDPT ray output — works on DA3-Small and DA3-Base.\n"
|
||||
"Both methods require DA3-Small or DA3-Base; this setting is ignored on Mono/Metric-Large."),
|
||||
]),
|
||||
]),
|
||||
],
|
||||
outputs=[
|
||||
DA3Geometry.Output("geometry",
|
||||
tooltip="DA3_GEOMETRY dict of raw tensors.\n"
|
||||
"- Always: 'depth' (B,H,W), 'image', 'mode'.\n"
|
||||
"- Optional: 'sky' + 'mask' (Mono/Metric), 'confidence' raw (Small/Base), 'extrinsics' + 'intrinsics' (multi-view)."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_model, image, process_res, resize_method, mode) -> io.NodeOutput:
|
||||
mode_val = mode["mode"] # "mono" or "multiview"
|
||||
|
||||
if mode_val == "mono":
|
||||
return cls._execute_mono(da3_model, image, process_res, resize_method)
|
||||
|
||||
# Capability checks for multi-view mode.
|
||||
diffusion = da3_model.model.diffusion_model
|
||||
pose_method = mode["pose_method"]
|
||||
ref_view_strategy = mode["ref_view_strategy"]
|
||||
|
||||
has_cam_dec = diffusion.cam_dec is not None
|
||||
has_dualdpt = diffusion.head_type == "dualdpt"
|
||||
|
||||
if not has_cam_dec and not has_dualdpt:
|
||||
raise ValueError(
|
||||
"multiview mode requires DA3-Small or DA3-Base — the loaded model "
|
||||
f"(head_type='{diffusion.head_type}') does not support cross-view "
|
||||
"attention or camera pose estimation. Switch mode to 'mono', or "
|
||||
"load DA3-Small / DA3-Base for multiview."
|
||||
)
|
||||
|
||||
if pose_method == "cam_dec" and not has_cam_dec:
|
||||
raise ValueError(
|
||||
"pose_method='cam_dec' requires a camera decoder, but the loaded "
|
||||
f"model (head_type='{diffusion.head_type}') does not have one. "
|
||||
"Use pose_method='ray_pose' instead."
|
||||
)
|
||||
if pose_method == "ray_pose" and not has_dualdpt:
|
||||
raise ValueError(
|
||||
"pose_method='ray_pose' requires a DualDPT head, but the loaded "
|
||||
f"model has a '{diffusion.head_type}' head. "
|
||||
"Use pose_method='cam_dec' instead."
|
||||
)
|
||||
|
||||
return cls._execute_multiview(
|
||||
da3_model, image, process_res, resize_method,
|
||||
ref_view_strategy, pose_method,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _execute_mono(cls, model, image, process_res, resize_method) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
"image": image[..., :3].cpu(),
|
||||
"mode": "mono",
|
||||
}
|
||||
if sky is not None:
|
||||
geometry["sky"] = sky.contiguous()
|
||||
if confidence is not None:
|
||||
geometry["confidence"] = confidence.contiguous()
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
@classmethod
|
||||
def _execute_multiview(cls, model, image, process_res, resize_method,
|
||||
ref_view_strategy, pose_method) -> io.NodeOutput:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
S, H, W, _ = image.shape
|
||||
|
||||
mm.load_model_gpu(model)
|
||||
diffusion = model.model.diffusion_model
|
||||
device = mm.get_torch_device()
|
||||
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
|
||||
|
||||
# All views in a single forward pass: (1, S, 3, H', W').
|
||||
x = image.to(device)
|
||||
x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method)
|
||||
x = x.to(dtype=dtype).unsqueeze(0)
|
||||
|
||||
use_ray_pose = (pose_method == "ray_pose")
|
||||
with torch.no_grad():
|
||||
out = diffusion(x, use_ray_pose=use_ray_pose,
|
||||
ref_view_strategy=ref_view_strategy)
|
||||
|
||||
depth = torch.nn.functional.interpolate(
|
||||
out["depth"].float().unsqueeze(1), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
sky = None
|
||||
if "sky" in out:
|
||||
sky = torch.nn.functional.interpolate(
|
||||
out["sky"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
if "extrinsics" in out and "intrinsics" in out:
|
||||
extrinsics = out["extrinsics"].float().cpu()
|
||||
intrinsics = out["intrinsics"].float().cpu()
|
||||
else:
|
||||
extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
|
||||
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
"image": image[..., :3].cpu(),
|
||||
"mode": "multiview",
|
||||
"extrinsics": extrinsics.contiguous(),
|
||||
"intrinsics": intrinsics.contiguous(),
|
||||
}
|
||||
if sky is not None:
|
||||
geometry["sky"] = sky.contiguous()
|
||||
if "depth_conf" in out:
|
||||
conf = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
geometry["confidence"] = conf.contiguous()
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
|
||||
|
||||
|
||||
class DA3Render(io.ComfyNode):
|
||||
"""Visualise a DA3_GEOMETRY packet as a single image.
|
||||
|
||||
Mirrors the MoGeRender interface: one ``output`` selector, one IMAGE out.
|
||||
Use multiple nodes in parallel to get depth + sky + confidence simultaneously.
|
||||
"""
|
||||
|
||||
_DEPTH_RENDER_INPUTS = [
|
||||
io.Combo.Input("normalization",
|
||||
options=["v2_style", "min_max", "raw"],
|
||||
default="v2_style",
|
||||
tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n"
|
||||
"- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n"
|
||||
"- raw: no scaling — preserves metric units for DA3-Metric-Large."),
|
||||
io.Boolean.Input("apply_sky_clip", default=False,
|
||||
tooltip="Clip sky-region depth to the 99th percentile of foreground depth before "
|
||||
"normalisation. Requires a 'sky' tensor in the geometry "
|
||||
"(DA3-Mono-Large or DA3-Metric-Large); raises an error otherwise."),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3Render",
|
||||
display_name="Depth Anything 3 Render",
|
||||
category="image/geometry_estimation",
|
||||
description="Visualise a DA3_GEOMETRY packet. Drop multiple nodes to get different views simultaneously.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("geometry"),
|
||||
io.DynamicCombo.Input("output",
|
||||
tooltip="- depth: normalised greyscale depth image.\n"
|
||||
"- depth_colored: depth mapped through the Turbo colormap.\n"
|
||||
"- sky_mask: sky probability in [0, 1] (Mono/Metric variants only).\n"
|
||||
"- confidence: normalised depth confidence (Small/Base variants only).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS),
|
||||
io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS),
|
||||
io.DynamicCombo.Option("sky_mask", []),
|
||||
io.DynamicCombo.Option("confidence", []),
|
||||
]),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, geometry, output) -> io.NodeOutput:
|
||||
output_val = output["output"]
|
||||
|
||||
if output_val in ("depth", "depth_colored"):
|
||||
normalization = output["normalization"]
|
||||
apply_sky_clip = output["apply_sky_clip"]
|
||||
if apply_sky_clip and "sky" not in geometry:
|
||||
raise ValueError(
|
||||
"apply_sky_clip=True requires a sky tensor in the geometry, but none is present. "
|
||||
"Run with DA3-Mono-Large or DA3-Metric-Large, or set apply_sky_clip=False."
|
||||
)
|
||||
depth = geometry["depth"]
|
||||
sky = geometry.get("sky")
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
grey = cls._depth_to_image(depth, sky, normalization) # (B,H,W,3) greyscale
|
||||
result = _turbo(grey[..., 0]) if output_val == "depth_colored" else grey
|
||||
|
||||
elif output_val == "sky_mask":
|
||||
if "sky" not in geometry:
|
||||
raise ValueError("geometry has no sky output; run with DA3-Mono-Large or DA3-Metric-Large.")
|
||||
sky = geometry["sky"]
|
||||
result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous()
|
||||
|
||||
elif output_val == "confidence":
|
||||
if "confidence" not in geometry:
|
||||
raise ValueError("geometry has no confidence output; run with DA3-Small or DA3-Base.")
|
||||
result = _normalize_confidence(geometry["confidence"])
|
||||
result = result.unsqueeze(-1).expand(*result.shape, 3).contiguous()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown output mode: {output_val}")
|
||||
|
||||
return io.NodeOutput(result.float())
|
||||
|
||||
@staticmethod
|
||||
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None,
|
||||
normalization: str) -> torch.Tensor:
|
||||
"""Normalise depth and pack as an (B,H,W,3) image tensor."""
|
||||
N = depth.shape[0]
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
da3_preprocess.normalize_depth_v2_style(
|
||||
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
|
||||
for i in range(N)
|
||||
], dim=0)
|
||||
elif normalization == "min_max":
|
||||
norm = da3_preprocess.normalize_depth_min_max(depth)
|
||||
else:
|
||||
norm = depth
|
||||
|
||||
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
|
||||
if normalization != "raw":
|
||||
out = out.clamp(0.0, 1.0)
|
||||
return out.contiguous()
|
||||
|
||||
|
||||
|
||||
|
||||
class DA3GeometryToMesh(io.ComfyNode):
|
||||
"""Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3GeometryToMesh",
|
||||
search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"],
|
||||
display_name="DA3 Geometry to Mesh",
|
||||
category="image/geometry_estimation",
|
||||
description="Convert a DA3_GEOMETRY depth map into a triangulated 3D mesh (Types.MESH).",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||
tooltip="Which frame of a batched DA3_GEOMETRY to mesh. "
|
||||
"Per-frame vertex counts differ so batches cannot be stacked."),
|
||||
io.Int.Input("decimation", default=1, min=1, max=8,
|
||||
tooltip="Vertex stride; 1 = full resolution, 2 = half, etc."),
|
||||
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Drop triangles whose 3×3 depth span exceeds this fraction. 0 = off."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). "
|
||||
"Ignored when the geometry has no confidence map (Mono/Metric models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True,
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. "
|
||||
"Ignored when the geometry has no sky map (Small/Base models)."),
|
||||
io.Boolean.Input("texture", default=True,
|
||||
tooltip="Carry the source image through as the baseColor texture."),
|
||||
],
|
||||
outputs=[io.Mesh.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold,
|
||||
confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
|
||||
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||
B = depth_all.shape[0]
|
||||
if batch_index >= B:
|
||||
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||
|
||||
depth = depth_all[batch_index] # (H, W)
|
||||
H, W = depth.shape
|
||||
|
||||
# NaN/inf depth would propagate silently through unproject and produce an
|
||||
# empty mesh; replace them with 0 here so those pixels are later excluded
|
||||
# by the isfinite check inside triangulate_grid_mesh.
|
||||
depth = depth.clone()
|
||||
n_bad = (~torch.isfinite(depth)).sum().item()
|
||||
if n_bad:
|
||||
logging.getLogger("comfy").warning(
|
||||
f"DA3GeometryToMesh: depth[{batch_index}] has {n_bad} non-finite pixels "
|
||||
f"({100*n_bad/(H*W):.1f}%) — zeroed before unproject."
|
||||
)
|
||||
depth[~torch.isfinite(depth)] = 0.0
|
||||
logging.getLogger("comfy").debug(
|
||||
f"DA3GeometryToMesh: depth[{batch_index}] range "
|
||||
f"[{depth.min():.4g}, {depth.max():.4g}], mean={depth.mean():.4g}"
|
||||
)
|
||||
|
||||
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||
points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV camera space
|
||||
|
||||
# Apply world-to-camera inverse so multi-view frames share a common world frame.
|
||||
E = _da3_get_extrinsic(da3_geometry, batch_index)
|
||||
if E is not None:
|
||||
points = _da3_apply_extrinsic(points, E)
|
||||
|
||||
# Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them.
|
||||
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||
# Also exclude pixels where depth was invalid.
|
||||
mask = mask & (depth_all[batch_index] > 0) & torch.isfinite(depth_all[batch_index])
|
||||
points = points.clone()
|
||||
points[~mask] = float('inf')
|
||||
|
||||
verts, faces, uvs = triangulate_grid_mesh(
|
||||
points,
|
||||
decimation=decimation,
|
||||
discontinuity_threshold=discontinuity_threshold,
|
||||
depth=depth,
|
||||
)
|
||||
if verts.shape[0] == 0 or faces.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"DA3GeometryToMesh produced an empty mesh. "
|
||||
"Try raising discontinuity_threshold, lowering confidence_threshold, "
|
||||
"or disabling use_sky_mask."
|
||||
)
|
||||
|
||||
# OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back).
|
||||
# Same transform as MoGePointMapToMesh perspective branch.
|
||||
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
|
||||
faces = faces[:, [0, 2, 1]].contiguous()
|
||||
|
||||
tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None
|
||||
mesh = Types.MESH(
|
||||
vertices=verts.unsqueeze(0),
|
||||
faces=faces.unsqueeze(0),
|
||||
uvs=uvs.unsqueeze(0),
|
||||
texture=tex,
|
||||
)
|
||||
return io.NodeOutput(mesh)
|
||||
|
||||
|
||||
class DA3GeometryToPointCloud(io.ComfyNode):
|
||||
"""Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3GeometryToPointCloud",
|
||||
search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"],
|
||||
display_name="DA3 Geometry to Point Cloud",
|
||||
category="image/geometry_estimation",
|
||||
description="Unproject a DA3_GEOMETRY depth map into a 3D point cloud (DA3_POINT_CLOUD).",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||
tooltip="Which frame of a batched DA3_GEOMETRY to convert."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). "
|
||||
"Ignored when the geometry has no confidence map."),
|
||||
io.Boolean.Input("use_sky_mask", default=True,
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5). "
|
||||
"Ignored when the geometry has no sky map."),
|
||||
io.Int.Input("downsample", default=1, min=1, max=16,
|
||||
tooltip="Take every Nth pixel (1 = full resolution). "
|
||||
"Higher values give fewer points and faster processing."),
|
||||
],
|
||||
# TODO: add a proper PointCloud output type
|
||||
outputs=[DA3PointCloud.Output(display_name="point_cloud")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, batch_index, confidence_threshold,
|
||||
use_sky_mask, downsample) -> io.NodeOutput:
|
||||
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||
B = depth_all.shape[0]
|
||||
if batch_index >= B:
|
||||
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||
|
||||
depth = depth_all[batch_index].clone() # (H, W)
|
||||
depth[~torch.isfinite(depth)] = 0.0
|
||||
H, W = depth.shape
|
||||
|
||||
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||
|
||||
if downsample > 1:
|
||||
depth = depth[::downsample, ::downsample].contiguous()
|
||||
# Scale intrinsics to the downsampled grid.
|
||||
K = K.clone()
|
||||
K[0, :] /= downsample
|
||||
K[1, :] /= downsample
|
||||
|
||||
H_ds, W_ds = depth.shape
|
||||
points = _da3_unproject(depth, K) # (H_ds, W_ds, 3) in OpenCV camera space
|
||||
|
||||
# Apply world-to-camera inverse so multi-view frames share a common world frame.
|
||||
E = _da3_get_extrinsic(da3_geometry, batch_index)
|
||||
if E is not None:
|
||||
points = _da3_apply_extrinsic(points, E)
|
||||
|
||||
# Rebuild mask at downsampled resolution.
|
||||
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||
if downsample > 1:
|
||||
mask = mask[::downsample, ::downsample]
|
||||
|
||||
mask = mask & torch.isfinite(depth)
|
||||
|
||||
# OpenCV → glTF: flip Y and Z.
|
||||
points_gltf = points.clone()
|
||||
points_gltf[..., 1] *= -1.0
|
||||
points_gltf[..., 2] *= -1.0
|
||||
|
||||
pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)]
|
||||
|
||||
colors_flat = None
|
||||
if "image" in da3_geometry:
|
||||
img = da3_geometry["image"][batch_index] # (H, W, 3)
|
||||
if downsample > 1:
|
||||
img = img[::downsample, ::downsample]
|
||||
colors_flat = img.reshape(-1, 3)[mask.reshape(-1)]
|
||||
|
||||
conf_flat = None
|
||||
if "confidence" in da3_geometry:
|
||||
conf = da3_geometry["confidence"][batch_index] # (H, W)
|
||||
if downsample > 1:
|
||||
conf = conf[::downsample, ::downsample]
|
||||
conf_flat = conf.reshape(-1)[mask.reshape(-1)]
|
||||
|
||||
if pts_flat.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"DA3GeometryToPointCloud produced zero points after filtering. "
|
||||
"Try lowering confidence_threshold or disabling use_sky_mask."
|
||||
)
|
||||
|
||||
return io.NodeOutput({
|
||||
"points": pts_flat,
|
||||
"colors": colors_flat,
|
||||
"confidence": conf_flat,
|
||||
})
|
||||
|
||||
|
||||
class DA3Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoadDA3Model,
|
||||
DA3Inference,
|
||||
DA3Render,
|
||||
DA3GeometryToMesh,
|
||||
# DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> DA3Extension:
|
||||
return DA3Extension()
|
||||
@ -226,20 +226,10 @@ def get_noise_mask(latent):
|
||||
noise_mask = noise_mask.clone()
|
||||
return noise_mask
|
||||
|
||||
def get_keyframe_idxs(cond, latent_shape=None):
|
||||
def get_keyframe_idxs(cond):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
# Get number of keyframes from latent_shape or guide_attention_entries if available
|
||||
if latent_shape is not None and len(latent_shape) == 5:
|
||||
tokens_per_frame = latent_shape[-2] * latent_shape[-1]
|
||||
num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame
|
||||
return keyframe_idxs, num_keyframes
|
||||
entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
|
||||
if entries:
|
||||
num_keyframes = sum(e["latent_shape"][0] for e in entries)
|
||||
return keyframe_idxs, num_keyframes
|
||||
# fallback, may under-count if keyframes share t-start
|
||||
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
@ -332,9 +322,9 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return factor
|
||||
|
||||
@classmethod
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
_, num_keyframes = get_keyframe_idxs(cond, latent_shape)
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1 and frame_idx != 0:
|
||||
@ -446,7 +436,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
||||
resolved_frame_idx = frame_idx
|
||||
if frame_idx < 0:
|
||||
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
|
||||
|
||||
@ -464,7 +454,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
if latent_downscale_factor > 1:
|
||||
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
|
||||
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||
@ -516,7 +506,7 @@ class LTXVCropGuides(io.ComfyNode):
|
||||
latent_image = latent["samples"].clone()
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
if num_keyframes == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, Types, io
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.ldm.colormap import turbo as _turbo
|
||||
from comfy.ldm.moge.model import MoGeModel
|
||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
|
||||
@ -27,19 +28,6 @@ MoGeGeometry = io.Custom("MOGE_GEOMETRY")
|
||||
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
|
||||
|
||||
|
||||
def _turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the turbo colormap."""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
|
||||
|
||||
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
|
||||
"""Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback)."""
|
||||
finite = torch.isfinite(points).all(dim=-1)
|
||||
|
||||
@ -1,55 +0,0 @@
|
||||
"""PiD (Pixel Diffusion Decoder) node"""
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import node_helpers
|
||||
import comfy.latent_formats
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class PiDConditioning(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="PiDConditioning",
|
||||
display_name="PiD Conditioning",
|
||||
category="advanced/conditioning",
|
||||
description=(
|
||||
"Attaches a latent and a degrade_sigma scalar to a CONDITIONING for PiD decoding/upscaling"
|
||||
),
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."),
|
||||
io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux",
|
||||
tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."),
|
||||
io.Float.Input(
|
||||
"degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput:
|
||||
samples = latent["samples"]
|
||||
if latent_format == "flux":
|
||||
fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux
|
||||
else:
|
||||
fmt_cls = comfy.latent_formats.SD3
|
||||
lq_latent = fmt_cls().process_in(samples)
|
||||
sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32)
|
||||
return io.NodeOutput(node_helpers.conditioning_set_values(
|
||||
positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t},
|
||||
))
|
||||
|
||||
|
||||
class PiDExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [PiDConditioning]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> PiDExtension:
|
||||
return PiDExtension()
|
||||
6
nodes.py
6
nodes.py
@ -969,7 +969,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -979,7 +979,7 @@ class CLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b\n pixeldit: gemma 2 2B elm"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@ -2420,7 +2420,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_context_windows.py",
|
||||
"nodes_qwen.py",
|
||||
"nodes_chroma_radiance.py",
|
||||
"nodes_pid.py",
|
||||
"nodes_model_patch.py",
|
||||
"nodes_easycache.py",
|
||||
"nodes_audio_encoder.py",
|
||||
@ -2455,6 +2454,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_save_3d.py",
|
||||
"nodes_moge.py",
|
||||
"nodes_mediapipe.py",
|
||||
"nodes_depth_anything_3.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
74
openapi.yaml
74
openapi.yaml
@ -275,10 +275,7 @@ paths:
|
||||
responses:
|
||||
"200":
|
||||
description: Queue updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/QueueManageResponse"
|
||||
|
||||
'400':
|
||||
description: Invalid request parameters
|
||||
content:
|
||||
@ -3095,34 +3092,18 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- asset_ids
|
||||
properties:
|
||||
job_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: Job IDs whose associated assets should all be included in the ZIP bundle.
|
||||
asset_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
format: uuid
|
||||
description: Asset IDs to include in the ZIP bundle. Additive to assets associated with provided job IDs.
|
||||
description: IDs of assets to export
|
||||
export_name:
|
||||
type: string
|
||||
description: Name for the export archive
|
||||
naming_strategy:
|
||||
type: string
|
||||
enum: [group_by_job_id, preserve, asset_id, group_by_job_time]
|
||||
default: group_by_job_time
|
||||
description: "Strategy for naming files in the ZIP: group by job ID, preserve original names, use the asset ID, or group by job creation time."
|
||||
job_asset_name_filters:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: array
|
||||
minItems: 1
|
||||
items:
|
||||
type: string
|
||||
description: Optional per-job asset name filters. When provided for a job ID, only assets whose name matches one of the listed names are included.
|
||||
responses:
|
||||
"202":
|
||||
description: Export task accepted
|
||||
@ -3594,7 +3575,10 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/HubLabelListResponse"
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/HubLabel"
|
||||
|
||||
'400':
|
||||
description: Bad request (e.g. invalid type parameter)
|
||||
content:
|
||||
@ -7482,25 +7466,6 @@ components:
|
||||
type: string
|
||||
description: Array of prompt IDs to delete from queue
|
||||
|
||||
QueueManageResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: >-
|
||||
[cloud-only] Result of a queue mutation. The Cloud runtime returns which
|
||||
items were deleted and whether the queue was cleared; local ComfyUI
|
||||
returns an empty 200 body.
|
||||
properties:
|
||||
deleted:
|
||||
type: array
|
||||
nullable: true
|
||||
items:
|
||||
type: string
|
||||
description: Prompt IDs that were deleted from the queue.
|
||||
cleared:
|
||||
type: boolean
|
||||
nullable: true
|
||||
description: Whether the queue was cleared.
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# History
|
||||
# -------------------------------------------------------------------
|
||||
@ -7581,16 +7546,6 @@ components:
|
||||
outputs_count:
|
||||
type: integer
|
||||
description: Total number of output files
|
||||
workflow_id:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] UUID of the Cloud workflow entity this job is associated with. Local ComfyUI returns null."
|
||||
execution_error:
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Detailed execution error from ComfyUI for failed jobs. Absent on local ComfyUI."
|
||||
allOf:
|
||||
- $ref: "#/components/schemas/ExecutionError"
|
||||
|
||||
JobDetailResponse:
|
||||
type: object
|
||||
@ -10478,19 +10433,6 @@ components:
|
||||
- custom_node
|
||||
description: Label category.
|
||||
|
||||
HubLabelListResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: '[cloud-only] Response wrapper for the available Hub label catalog.'
|
||||
required:
|
||||
- labels
|
||||
properties:
|
||||
labels:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/HubLabelInfo'
|
||||
description: Available labels, optionally filtered by type.
|
||||
|
||||
HubProfileSummary:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
|
||||
Reference in New Issue
Block a user