mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-17 04:07:31 +08:00
Compare commits
8 Commits
feat/load-
...
feat/cube3
| Author | SHA1 | Date | |
|---|---|---|---|
| e7f99168ae | |||
| 029b782936 | |||
| 81f5f84ad6 | |||
| d8635dcb39 | |||
| aeb3c77ae9 | |||
| a6c7397b71 | |||
| 871f7bc390 | |||
| 01a8783bee |
@ -1955,3 +1955,120 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
transformer_options.pop("ar_state", None)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _cube_process_logits(logits, top_p, generator):
    """Select the next token ID from a batch of logits.

    Args:
        logits: 2D float tensor (batch, vocab); the last dimension indexes tokens.
        top_p: nucleus threshold. ``None``, ``>= 1`` or ``<= 0`` selects greedy
            argmax (deterministic). Otherwise tokens are sampled from the smallest
            high-probability set whose cumulative probability reaches ``top_p``.
        generator: optional ``torch.Generator`` forwarded to ``torch.multinomial``.

    Returns:
        Long tensor of shape (batch, 1) holding the chosen token indices.
    """
    if top_p is None or top_p >= 1.0 or top_p <= 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # A token is dropped only when the probability mass strictly *before* it
    # already exceeds top_p. Shifting the mask right by one keeps the token that
    # crosses the threshold, so the retained mass is never below top_p and the
    # most likely token is always kept.
    drop_sorted = torch.zeros_like(cum_probs, dtype=torch.bool)
    drop_sorted[..., 1:] = cum_probs[..., :-1] > top_p

    # Map the mask from sorted order back to vocabulary order.
    drop = drop_sorted.scatter(-1, sorted_idx, drop_sorted)
    filtered = logits.masked_fill(drop, float("-inf"))
    probs = torch.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_cube(model, x, sigmas, extra_args=None, callback=None, disable=None, top_p=1.0):
    """
    Autoregressive sampler for Roblox Cube3D shape GPT (DualStreamRoformer).

    Not a diffusion sampler: the values of the noised input `x` are ignored and
    only its shape (batch, 1, num_tokens) is used. One VQ token ID is generated per
    latent position from CLIP text conditioning, with upstream's linearly-decaying
    CFG and optional top-p. Plugs into SamplerCustomAdvanced via the SamplerCube node.

    Faithful to cube3d.inference.engine.Engine.run_gpt:
      gamma_i = cfg * (T - i) / T ;  logits = (1+gamma)*cond - gamma*uncond
      fp32 weights + bf16 autocast on cuda.

    Args:
        model: sampler wrapper; `model.inner_model` is read as the guider and
            `model.inner_model.inner_model` as the base model.
        x: latent tensor; supplies batch size (dim 0), sequence length (last dim)
            and the compute device.
        sigmas: not used for sampling; `sigmas[0]` is only forwarded to `callback`.
        extra_args: optional dict; `extra_args["seed"]` (default 0) seeds top-p sampling.
        callback: optional progress callable, invoked once per generated token.
        disable: forwarded to `trange` to silence the progress bar.
        top_p: nucleus threshold; values outside (0, 1) select greedy argmax.

    Returns:
        float32 tensor of shape (batch, 1, num_tokens) holding token IDs.

    Raises:
        ValueError: if the guider carries no positive conditioning.
    """
    import contextlib

    import comfy.model_management
    import comfy.utils

    extra_args = {} if extra_args is None else extra_args

    guider = model.inner_model              # CFGGuider
    base_model = guider.inner_model         # BaseModel (Cube3D)
    cube = base_model.diffusion_model
    cfg = getattr(guider, "cfg", 3.0)

    def get_cond(name):
        conds = guider.conds.get(name, None)
        if not conds:
            return None
        return conds[0]["model_conds"]["c_crossattn"].cond

    pos = get_cond("positive")
    neg = get_cond("negative")
    if pos is None:
        raise ValueError("sample_cube requires positive conditioning (CLIP-L text embeds).")

    device = x.device
    weight_dtype = base_model.get_dtype()
    T = x.shape[-1]  # sequence length; latent is (batch, 1, num_tokens)
    batch = x.shape[0]
    pos = comfy.utils.repeat_to_batch_size(pos, batch)
    if neg is not None:
        neg = comfy.utils.repeat_to_batch_size(neg, batch)
    use_cfg = (cfg is not None) and (cfg > 0.0) and (neg is not None)
    autocast_enabled = (device.type == "cuda")
    cache_dtype = torch.bfloat16 if autocast_enabled else weight_dtype

    def add_bbox(c):
        if not getattr(cube, "use_bbox", False):
            return c
        bbox = torch.zeros((c.shape[0], 3), device=device, dtype=c.dtype)
        return torch.cat([c, cube.bbox_proj(bbox).unsqueeze(1)], dim=1)

    # Conditioning (text_proj + bbox_proj) is computed in the model's weight dtype
    # OUTSIDE the bf16 autocast block, matching upstream cube's Engine.prepare_inputs
    # (run_clip/encode_text run in full precision). The autocast only covers the
    # autoregressive transformer forward, exactly like Engine.run_gpt.
    cond = add_bbox(cube.encode_text(pos.to(device=device, dtype=weight_dtype)))
    if use_cfg:
        ucond = add_bbox(cube.encode_text(neg.to(device=device, dtype=weight_dtype)))
        cond = torch.cat([cond, ucond], dim=0)

    # Only build a torch.autocast context when it is actually used. Constructing
    # one for a device type without autocast support (e.g. "mps" on older torch
    # releases) raises RuntimeError even with enabled=False.
    if autocast_enabled:
        autocast_ctx = torch.autocast(device_type=device.type, dtype=torch.bfloat16)
    else:
        autocast_ctx = contextlib.nullcontext()

    with autocast_ctx:
        bos = torch.full((cond.shape[0], 1), cube.shape_bos_id, dtype=torch.long, device=device)
        embed = cube.encode_token(bos)
        Bp, input_seq_len, dim = embed.shape
        # Pre-allocated embedding buffer: BOS followed by one slot per generated token.
        embed_buffer = torch.zeros((Bp, input_seq_len + T, dim), dtype=embed.dtype, device=device)
        embed_buffer[:, :input_seq_len, :].copy_(embed)

        kv_cache = cube.init_kv_cache(Bp, cond.shape[1], T + 1, cache_dtype, device)

        # Restrict sampling to real codes by slicing off the last three vocab entries.
        num_codes = cube.vocab_size - 3
        seed = extra_args.get("seed", 0)
        generator = None
        if device.type != "mps":
            generator = torch.Generator(device=device).manual_seed(int(seed))

        output_ids = []
        for i in trange(T, disable=disable):
            comfy.model_management.throw_exception_if_processing_interrupted()
            curr_pos_id = torch.tensor([i], dtype=torch.long, device=device)
            # Step 0 is the prefill pass; later steps decode one position via the KV cache.
            logits = cube(embed_buffer, cond, kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=(i > 0))
            logits = logits[:, 0, :num_codes]

            if use_cfg:
                cond_logits, uncond_logits = logits.float().chunk(2, dim=0)
                gamma = cfg * (T - i) / T  # guidance decays linearly towards 0
                logits = (1.0 + gamma) * cond_logits - gamma * uncond_logits
            else:
                logits = logits.float()

            next_id = _cube_process_logits(logits, top_p, generator)
            output_ids.append(next_id)

            next_embed = cube.encode_token(next_id)
            if use_cfg:
                # Cond and uncond halves are fed the same sampled token.
                next_embed = torch.cat([next_embed, next_embed], dim=0)
            embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))

            if callback is not None:
                callback({"x": x, "i": i, "sigma": sigmas[0], "sigma_hat": sigmas[0], "denoised": x})

    # (B, T) token IDs -> (B, 1, T) to keep the channels-first 1D latent layout.
    return torch.cat(output_ids, dim=1).to(torch.float32).unsqueeze(1)
|
||||
|
||||
@ -775,6 +775,16 @@ class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_dimensions = 1
|
||||
scale_factor = 1.0188137142395404
|
||||
|
||||
class Cube3D(LatentFormat):
    # The Roblox Cube3D shape "latent" is simply a flat run of VQ token IDs, i.e. one
    # scalar per position. It is therefore stored channels-first as (B, 1, num_tokens),
    # following the (B, C, L) layout used by Hunyuan3Dv2.
    #
    # * latent_channels = 1 stops fix_empty_latent_channels from truncating the
    #   token sequence.
    # * scale_factor = 1.0 because token IDs have to survive process_latent_in/out
    #   untouched.
    scale_factor = 1.0
    latent_dimensions = 1
    latent_channels = 1
|
||||
|
||||
class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
417
comfy/ldm/cube/gpt.py
Normal file
417
comfy/ldm/cube/gpt.py
Normal file
@ -0,0 +1,417 @@
|
||||
"""
|
||||
Native port of Roblox/cube's shape GPT (DualStreamRoformer).
|
||||
|
||||
Reference: https://github.com/Roblox/cube (cube3d/model/gpt/dual_stream_roformer.py
|
||||
and cube3d/model/transformers/*).
|
||||
|
||||
This is an autoregressive transformer over discrete VQ shape tokens, conditioned on
|
||||
CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated
|
||||
`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler.
|
||||
|
||||
The forward pass is kept faithful to upstream so token IDs match bit-for-bit:
|
||||
* rope_theta = 10000
|
||||
* per-head RMSNorm on Q and K
|
||||
* dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only
|
||||
* two separate RoPE frequency tensors (dual blocks offset cond tokens by S)
|
||||
* SwiGLU MLP, non-affine LayerNorm upcast to fp32
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Norms (faithful to cube3d/model/transformers/norm.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CubeLayerNorm(nn.Module):
    """LayerNorm with no learnable scale/shift, evaluated in fp32 (matches cube).

    The input is upcast to float32 for the normalisation and the result is cast
    back to the input's type.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # F.layer_norm expects the normalised shape as a tuple.
        self.dim = (dim,)

    def forward(self, x):
        normed = F.layer_norm(x.float(), self.dim, weight=None, bias=None, eps=self.eps)
        return normed.type_as(x)
|
||||
|
||||
|
||||
class CubeRMSNorm(nn.Module):
    """RMSNorm over the last dimension with a learnable gain, computed in fp32 (matches cube).

    Args:
        dim: size of the normalised (last) dimension; also the length of `weight`.
        eps: value added to the mean square before the reciprocal square root.
        dtype, device: forwarded to the tensor that initialises `weight` (all ones).
    """

    def __init__(self, dim, eps=1e-5, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))

    def forward(self, x):
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        scaled = x32 * inv_rms * self.weight
        # Hand the result back in the caller's original type.
        return scaled.type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE (faithful to cube3d/model/transformers/rope.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
    """Rotate consecutive feature pairs of `x` by complex rotary factors.

    Args:
        x: 4D tensor whose dim 2 is the sequence axis and whose last dim is even;
            adjacent pairs of the last dim are treated as (real, imag).
        freqs_cis: complex tensor indexed as [batch, position, pair].
        curr_pos_id: optional index into the position axis of `freqs_cis`. When
            omitted, the trailing `x.shape[2]` positions are used.

    Returns:
        Tensor with the shape and type of `x`.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(pairs)

    if curr_pos_id is not None:
        rotation = freqs_cis[:, curr_pos_id, :]
    else:
        rotation = freqs_cis[:, -x.shape[2]:]
    # Insert a singleton axis at dim 1 so the factors broadcast over x's dim 1.
    rotation = rotation.unsqueeze(1)

    rotated = torch.view_as_real(as_complex * rotation)
    return rotated.flatten(3).type_as(x)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim, t, theta=10000.0):
    """Build unit-magnitude complex rotary factors for the positions in `t`.

    Args:
        dim: feature width being rotated; `dim // 2` frequencies are produced
            (even indices 0, 2, ... of `dim`).
        t: tensor of positions; its device is used for the result.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape `(*t.shape, dim // 2)` equal to exp(i * t * freq).
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim
    inv_freq = 1.0 / (theta ** exponents)
    flat_positions = t.contiguous().view(-1)
    angles = torch.outer(flat_positions, inv_freq).reshape(*t.shape, -1)
    unit = torch.ones_like(angles)
    return torch.polar(unit, angles)
|
||||
|
||||
|
||||
def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):
    """Apply rotary embeddings to `q` and `k`, then run scaled dot-product attention.

    `q` is rotated at `curr_pos_id` when one is given; `k` is always rotated using
    the trailing positions of `freqs_cis`. The implicit causal mask is requested
    only when `is_causal` is set and no explicit `attn_mask` was supplied.
    """
    rotated_q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id)
    rotated_k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None)
    # An explicit mask takes precedence over the built-in causal masking.
    use_causal = is_causal if attn_mask is None else False
    return F.scaled_dot_product_attention(
        rotated_q,
        rotated_k,
        v,
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=use_causal,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KV cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Cache:
    """Pair of pre-allocated key/value buffers updated in place along dim 2."""

    def __init__(self, key_states, value_states):
        self.key_states = key_states
        self.value_states = value_states

    def update(self, curr_pos_id, k, v):
        """Write `k` and `v` into the buffers at index tensor `curr_pos_id` (dim 2)."""
        for buffer, fresh in ((self.key_states, k), (self.value_states, v)):
            buffer.index_copy_(2, curr_pos_id, fresh)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared building blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SwiGLUMLP(nn.Module):
    """Gated feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x)).

    Args:
        embed_dim: input and output width.
        hidden_dim: width of the gated hidden layer.
        bias: whether the three linear layers carry a bias.
        dtype, device: forwarded to every linear layer.
        operations: namespace providing the `Linear` class used to build the layers.
    """

    def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(bias=bias, dtype=dtype, device=device)
        self.gate_proj = operations.Linear(embed_dim, hidden_dim, **factory)
        self.up_proj = operations.Linear(embed_dim, hidden_dim, **factory)
        self.down_proj = operations.Linear(hidden_dim, embed_dim, **factory)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
||||
|
||||
|
||||
class SelfAttentionWithRotaryEmbedding(nn.Module):
    """Multi-head self-attention with per-head RMSNorm on Q/K and rotary embeddings.

    Q and K come from one fused bias-free projection (`c_qk`); V and the output
    projection honour `bias`. The `eps` argument is accepted but not used here.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        per_head = embed_dim // num_heads
        self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device)
        self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
        self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
        self.q_norm = CubeRMSNorm(per_head, dtype=dtype, device=device)
        self.k_norm = CubeRMSNorm(per_head, dtype=dtype, device=device)

    def forward(self, x, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
        """Attend over `x` of shape (b, l, d); returns a tensor of the same shape.

        With a `kv_cache`, the prefill pass (`decode=False`) writes K/V into the
        leading slots of the cache and a decode pass writes them at `curr_pos_id`;
        attention then always reads the whole cache buffers.
        """
        batch, seq_len, width = x.shape

        def split_heads(t):
            # (b, l, d) -> (b, heads, l, head_dim)
            return t.view(batch, seq_len, self.num_heads, -1).transpose(1, 2)

        q, k = self.c_qk(x).chunk(2, dim=-1)
        v = self.c_v(x)
        q = self.q_norm(split_heads(q))
        k = self.k_norm(split_heads(k))
        v = split_heads(v)

        if kv_cache is not None:
            if decode:
                kv_cache.update(curr_pos_id, k, v)
            else:
                filled = k.shape[2]
                kv_cache.key_states[:, :, :filled, :].copy_(k)
                kv_cache.value_states[:, :, :filled, :].copy_(v)
            k, v = kv_cache.key_states, kv_cache.value_states

        attended = sdpa_with_rope(
            q, k, v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=curr_pos_id if decode else None,
            is_causal=is_causal,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
|
||||
|
||||
|
||||
class DecoderLayerWithRotaryEmbedding(nn.Module):
    """Single-stream decoder layer (shape tokens only).

    Pre-norm residual layout: x += attn(ln_1(x)), then x += mlp(ln_2(x)), where the
    MLP hidden width is four times `embed_dim`.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
        self.attn = SelfAttentionWithRotaryEmbedding(
            embed_dim, num_heads, bias=bias, eps=eps,
            dtype=dtype, device=device, operations=operations,
        )
        self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = SwiGLUMLP(
            embed_dim, embed_dim * 4, bias=bias,
            dtype=dtype, device=device, operations=operations,
        )

    def forward(self, x, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
        attn_out = self.attn(
            self.ln_1(x),
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = x + attn_out
        return x + self.mlp(self.ln_2(x))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dual-stream blocks (faithful to dual_stream_attention.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DismantledPreAttention(nn.Module):
    """Projection half of an attention block: produces per-head (q, k, v).

    With `query=True` a fused bias-free `c_qk` yields both Q and K and Q is
    RMS-normalised. With `query=False` only K is projected (`c_k`) and the returned
    q is None. K is always RMS-normalised; V is a plain linear projection.
    """

    def __init__(self, embed_dim, num_heads, query=True, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.query = query
        per_head = embed_dim // num_heads
        if query:
            self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device)
            self.q_norm = CubeRMSNorm(per_head, dtype=dtype, device=device)
        else:
            self.c_k = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
        self.k_norm = CubeRMSNorm(per_head, dtype=dtype, device=device)
        self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
        self.num_heads = num_heads

    def _to_mha(self, x):
        """Reshape (b, l, d) into (b, heads, l, d // heads)."""
        lead = x.shape[:2]
        return x.view(*lead, self.num_heads, -1).transpose(1, 2)

    def forward(self, x):
        if not self.query:
            q = None
            k = self.c_k(x)
        else:
            q, k = self.c_qk(x).chunk(2, dim=-1)
            q = self.q_norm(self._to_mha(q))
        k = self.k_norm(self._to_mha(k))
        v = self._to_mha(self.c_v(x))
        return (q, k, v)
|
||||
|
||||
|
||||
class DismantledPostAttention(nn.Module):
    """Output half of an attention block: projection residual followed by an MLP residual.

    Computes x += c_proj(a), then x += mlp(ln_3(x)), with the MLP hidden width set
    to four times `embed_dim`.
    """

    def __init__(self, embed_dim, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
        self.ln_3 = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = SwiGLUMLP(
            embed_dim, embed_dim * 4, bias=bias,
            dtype=dtype, device=device, operations=operations,
        )

    def forward(self, x, a):
        """`x` is the residual stream, `a` the attention output for that stream."""
        after_attn = x + self.c_proj(a)
        return after_attn + self.mlp(self.ln_3(after_attn))
|
||||
|
||||
|
||||
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
    """Joint attention over a conditioning stream `c` and a token stream `x`.

    Each stream has its own pre-attention projections. When `cond_pre_only` is set
    the conditioning stream contributes keys/values only (no queries).
    """

    def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.cond_pre_only = cond_pre_only
        self.pre_x = DismantledPreAttention(
            embed_dim, num_heads, query=True, bias=bias,
            dtype=dtype, device=device, operations=operations,
        )
        self.pre_c = DismantledPreAttention(
            embed_dim, num_heads, query=not cond_pre_only, bias=bias,
            dtype=dtype, device=device, operations=operations,
        )

    def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
        """Return `(y_x, y_c)`; `y_c` is None when no conditioning queries were attended.

        Without a cache, or on the prefill pass, both streams are projected and
        concatenated along the sequence axis (conditioning first). On a cached decode
        pass only `x` is projected and the implicit causal mask is switched off.
        """
        joint_pass = kv_cache is None or not decode
        if joint_pass:
            c_q, c_k, c_v = self.pre_c(c)
            x_q, x_k, x_v = self.pre_x(x)
            q = x_q if self.cond_pre_only else torch.cat([c_q, x_q], dim=2)
            k = torch.cat([c_k, x_k], dim=2)
            v = torch.cat([c_v, x_v], dim=2)
        else:
            is_causal = False
            q, k, v = self.pre_x(x)

        if kv_cache is not None:
            if decode:
                kv_cache.update(curr_pos_id, k, v)
            else:
                filled = k.shape[2]
                kv_cache.key_states[:, :, :filled, :].copy_(k)
                kv_cache.value_states[:, :, :filled, :].copy_(v)
            k, v = kv_cache.key_states, kv_cache.value_states

        if attn_mask is not None:
            # Keep only the mask rows that belong to the queries actually present.
            rows = curr_pos_id if decode else slice(-q.shape[2], None)
            attn_mask = attn_mask[..., rows, :]

        attended = sdpa_with_rope(
            q, k, v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=curr_pos_id if decode else None,
            is_causal=is_causal,
        )
        y = attended.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])

        # Output as long as x means it holds token-stream results only.
        if y.shape[1] == x.shape[1]:
            return y, None
        y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1)
        return y_x, y_c
|
||||
|
||||
|
||||
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
    """Dual-stream decoder layer: joint attention plus a per-stream post block.

    `post_1` updates the token stream; `post_2` updates the conditioning stream and
    is only created when `cond_pre_only` is False.
    """

    def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, eps=1e-6,
                 dtype=None, device=None, operations=None):
        super().__init__()
        self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
        self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
        self.attn = DualStreamAttentionWithRotaryEmbedding(
            embed_dim, num_heads, cond_pre_only=cond_pre_only,
            bias=bias, dtype=dtype, device=device, operations=operations,
        )
        self.post_1 = DismantledPostAttention(
            embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations,
        )
        if not cond_pre_only:
            self.post_2 = DismantledPostAttention(
                embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations,
            )

    def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
        """Return the updated `(x, c)`; `c` becomes None once attention yields no conditioning output."""
        x_normed = self.ln_1(x)
        c_normed = None if c is None else self.ln_2(c)
        a_x, a_c = self.attn(
            x_normed,
            c_normed,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = self.post_1(x, a_x)
        c = self.post_2(c, a_c) if a_c is not None else None
        return x, c
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DualStreamRoformer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DualStreamRoformer(nn.Module):
    """Autoregressive shape-token transformer: dual-stream blocks followed by single-stream blocks.

    The embedding table and LM head cover `shape_model_vocab_size + 3` entries; the
    three extra IDs are, in order, BOS, EOS and padding. The last dual block is built
    with `cond_pre_only=True`.
    """

    def __init__(
        self,
        n_layer=23,
        n_single_layer=1,
        rope_theta=10000,
        n_head=12,
        n_embd=1536,
        bias=True,
        eps=1e-6,
        shape_model_vocab_size=16384,
        shape_model_embed_dim=32,
        text_model_embed_dim=768,
        use_bbox=True,
        image_model=None,  # detection key; unused
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.n_layer = n_layer
        self.n_single_layer = n_single_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.rope_theta = rope_theta
        self.head_dim = n_embd // n_head

        self.text_proj = operations.Linear(text_model_embed_dim, n_embd, bias=bias, dtype=dtype, device=device)
        self.shape_proj = operations.Linear(shape_model_embed_dim, n_embd, bias=True, dtype=dtype, device=device)

        # Special token IDs are appended right after the real codes.
        self.shape_bos_id = shape_model_vocab_size
        self.shape_eos_id = shape_model_vocab_size + 1
        self.padding_id = shape_model_vocab_size + 2
        self.vocab_size = shape_model_vocab_size + 3

        last_dual = n_layer - 1
        dual_blocks = nn.ModuleList([
            DualStreamDecoderLayerWithRotaryEmbedding(
                n_embd, n_head, cond_pre_only=(idx == last_dual), bias=bias, eps=eps,
                dtype=dtype, device=device, operations=operations,
            )
            for idx in range(n_layer)
        ])
        single_blocks = nn.ModuleList([
            DecoderLayerWithRotaryEmbedding(
                n_embd, n_head, bias=bias, eps=eps,
                dtype=dtype, device=device, operations=operations,
            )
            for _ in range(n_single_layer)
        ])
        self.transformer = nn.ModuleDict(dict(
            wte=operations.Embedding(self.vocab_size, n_embd, padding_idx=self.padding_id, dtype=dtype, device=device),
            dual_blocks=dual_blocks,
            single_blocks=single_blocks,
            ln_f=CubeLayerNorm(n_embd, eps=eps),
        ))

        self.lm_head = operations.Linear(n_embd, self.vocab_size, bias=False, dtype=dtype, device=device)

        self.use_bbox = use_bbox
        if use_bbox:
            self.bbox_proj = operations.Linear(3, n_embd, bias=True, dtype=dtype, device=device)

    def encode_text(self, text_embed):
        """Project text embeddings to the model width via `text_proj`."""
        return self.text_proj(text_embed)

    def encode_token(self, tokens):
        """Look up token IDs in the embedding table `wte`."""
        return self.transformer.wte(tokens)

    def init_kv_cache(self, batch_size, cond_len, max_shape_tokens, dtype, device):
        """Allocate one zeroed `Cache` per block, dual blocks first.

        Dual-block caches hold `cond_len + max_shape_tokens` positions; single-block
        caches hold `max_shape_tokens` positions.
        """
        def new_cache(length):
            shape = (batch_size, self.n_head, length, self.head_dim)
            return Cache(
                torch.zeros(shape, dtype=dtype, device=device),
                torch.zeros(shape, dtype=dtype, device=device),
            )

        joint_len = cond_len + max_shape_tokens
        caches = [new_cache(joint_len) for _ in self.transformer.dual_blocks]
        caches.extend(new_cache(max_shape_tokens) for _ in self.transformer.single_blocks)
        return caches

    def forward(self, embed, cond, kv_cache=None, curr_pos_id=None, decode=False):
        """Run the transformer and return LM-head logits.

        Args:
            embed: token embeddings; dims 0 and 1 are read as (batch, length).
            cond: conditioning embeddings; dim 1 is the conditioning length.
            kv_cache: optional list of `Cache` objects as built by `init_kv_cache`.
            curr_pos_id: index tensor of the position(s) being decoded.
            decode: when True and a cache is given, only `embed[:, curr_pos_id]` is processed.
        """
        b, l = embed.shape[:2]
        s = cond.shape[1]
        device = embed.device
        joint_len = s + l

        # Lower-triangular mask over the concatenated (cond + shape) sequence.
        attn_mask = torch.tril(torch.ones(joint_len, joint_len, dtype=torch.bool, device=device))

        # Single-stream RoPE: positions 0..l-1 of the shape tokens.
        shape_positions = torch.arange(l, dtype=torch.long, device=device).unsqueeze(0).expand(b, -1)
        s_freqs_cis = precompute_freqs_cis(self.head_dim, shape_positions, theta=self.rope_theta)

        # Dual-stream RoPE: every conditioning token sits at position 0, shape tokens follow.
        joint_positions = torch.cat(
            [torch.zeros([b, s], dtype=torch.long, device=device), shape_positions],
            dim=1,
        )
        d_freqs_cis = precompute_freqs_cis(self.head_dim, joint_positions, theta=self.rope_theta)

        has_cache = kv_cache is not None
        if has_cache and decode:
            embed = embed[:, curr_pos_id, :]

        h, c = embed, cond
        # In the joint sequence the shape tokens are offset by the conditioning length.
        joint_pos_id = curr_pos_id + s if curr_pos_id is not None else None

        dual_blocks = self.transformer.dual_blocks
        for idx, block in enumerate(dual_blocks):
            h, c = block(
                h,
                c=c,
                freqs_cis=d_freqs_cis,
                attn_mask=attn_mask,
                is_causal=True,
                kv_cache=kv_cache[idx] if has_cache else None,
                curr_pos_id=joint_pos_id,
                decode=decode,
            )

        for idx, block in enumerate(self.transformer.single_blocks, start=len(dual_blocks)):
            h = block(
                h,
                freqs_cis=s_freqs_cis,
                attn_mask=None,
                is_causal=True,
                kv_cache=kv_cache[idx] if has_cache else None,
                curr_pos_id=curr_pos_id,
                decode=decode,
            )

        return self.lm_head(self.transformer.ln_f(h))
|
||||
379
comfy/ldm/cube/marching_cubes.py
Normal file
379
comfy/ldm/cube/marching_cubes.py
Normal file
@ -0,0 +1,379 @@
|
||||
"""Dependency-free marching cubes (classic Lorensen/Cline) in pure PyTorch.
|
||||
|
||||
Vendored so Cube3D mesh extraction needs no scikit-image. This is the same
|
||||
algorithm family as upstream cube's default NVIDIA-warp backend (warp.MarchingCubes),
|
||||
so geometry is closer to the upstream default than skimage's Lewiner fallback.
|
||||
|
||||
Output convention matches skimage.measure.marching_cubes: vertices are returned in
|
||||
array-index coordinates (axis 0, axis 1, axis 2 of the input volume), so the caller's
|
||||
`vertices / grid_size * bbox_size + bbox_min` transform applies unchanged.
|
||||
|
||||
The standard 256-entry triangle table (Paul Bourke / Cory Bloyd) is used with the
|
||||
canonical corner and edge numbering:
|
||||
|
||||
corners (x, y, z): edges (corner pairs):
|
||||
0: (0,0,0) 1: (1,0,0) 0:0-1 1:1-2 2:2-3 3:3-0
|
||||
2: (1,1,0) 3: (0,1,0) 4:4-5 5:5-6 6:6-7 7:7-4
|
||||
4: (0,0,1) 5: (1,0,1) 8:0-4 9:1-5 10:2-6 11:3-7
|
||||
6: (1,1,1) 7: (0,1,1)
|
||||
|
||||
Here x maps to volume axis 0, y to axis 1, z to axis 2.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Integer (axis0, axis1, axis2) offsets of the 8 cube corners, listed in the
# corner numbering the triangle table expects (row k is corner k).
_CORNERS = np.array(
    [
        (0, 0, 0),  # corner 0
        (1, 0, 0),  # corner 1
        (1, 1, 0),  # corner 2
        (0, 1, 0),  # corner 3
        (0, 0, 1),  # corner 4
        (1, 0, 1),  # corner 5
        (1, 1, 1),  # corner 6
        (0, 1, 1),  # corner 7
    ],
    dtype=np.int64,
)
|
||||
|
||||
# For each of the 12 cube edges (row e is edge e), the pair of corner indices
# it joins, in the order the edge is traversed.
_EDGE_CORNERS = np.array(
    [
        (0, 1), (1, 2), (2, 3), (3, 0),  # edges 0-3
        (4, 5), (5, 6), (6, 7), (7, 4),  # edges 4-7
        (0, 4), (1, 5), (2, 6), (3, 7),  # edges 8-11
    ],
    dtype=np.int64,
)
|
||||
|
||||
# Standard 256 x 16 triangle table. For cube configuration `i`, lists triples of
|
||||
# edge indices forming triangles, terminated by -1.
|
||||
_TRI_TABLE = [
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1],
|
||||
[8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1],
|
||||
[3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1],
|
||||
[4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1],
|
||||
[4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1],
|
||||
[9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1],
|
||||
[10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1],
|
||||
[5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1],
|
||||
[5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1],
|
||||
[8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1],
|
||||
[2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1],
|
||||
[2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1],
|
||||
[11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1],
|
||||
[5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1],
|
||||
[11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1],
|
||||
[11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1],
|
||||
[2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1],
|
||||
[6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1],
|
||||
[3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1],
|
||||
[6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1],
|
||||
[6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1],
|
||||
[8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1],
|
||||
[7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1],
|
||||
[3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1],
|
||||
[0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1],
|
||||
[9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1],
|
||||
[8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1],
|
||||
[5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1],
|
||||
[0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1],
|
||||
[6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1],
|
||||
[10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1],
|
||||
[1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1],
|
||||
[0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1],
|
||||
[3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1],
|
||||
[6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1],
|
||||
[9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1],
|
||||
[8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1],
|
||||
[3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1],
|
||||
[10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1],
|
||||
[10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1],
|
||||
[2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1],
|
||||
[7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1],
|
||||
[2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1],
|
||||
[1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1],
|
||||
[11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1],
|
||||
[8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1],
|
||||
[0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1],
|
||||
[7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1],
|
||||
[7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1],
|
||||
[10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1],
|
||||
[0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1],
|
||||
[7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1],
|
||||
[6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1],
|
||||
[4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1],
|
||||
[10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1],
|
||||
[8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1],
|
||||
[1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1],
|
||||
[10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1],
|
||||
[10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1],
|
||||
[9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1],
|
||||
[7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1],
|
||||
[3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1],
|
||||
[7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1],
|
||||
[3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1],
|
||||
[6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1],
|
||||
[9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1],
|
||||
[1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1],
|
||||
[4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1],
|
||||
[7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1],
|
||||
[6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1],
|
||||
[0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1],
|
||||
[6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1],
|
||||
[0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1],
|
||||
[11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1],
|
||||
[6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1],
|
||||
[5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1],
|
||||
[9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1],
|
||||
[1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1],
|
||||
[10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1],
|
||||
[0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1],
|
||||
[11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1],
|
||||
[9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1],
|
||||
[7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1],
|
||||
[2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1],
|
||||
[9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1],
|
||||
[9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1],
|
||||
[1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1],
|
||||
[0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1],
|
||||
[10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1],
|
||||
[2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1],
|
||||
[0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1],
|
||||
[0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1],
|
||||
[9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1],
|
||||
[5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1],
|
||||
[5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1],
|
||||
[8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1],
|
||||
[9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1],
|
||||
[1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1],
|
||||
[3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1],
|
||||
[4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1],
|
||||
[9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1],
|
||||
[11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1],
|
||||
[2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1],
|
||||
[9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1],
|
||||
[3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1],
|
||||
[1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1],
|
||||
[4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1],
|
||||
[0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1],
|
||||
[1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
def marching_cubes(volume: torch.Tensor, level: float = 0.0):
    """Extract a triangle mesh of the ``level`` isosurface from a 3D scalar field.

    Every grid cell is classified by which of its 8 corners lie below ``level``.
    The triangle table then names the cell edges each triangle corner sits on, and
    a vertex is placed on each such edge by linear interpolation between the two
    corner values. Vertices produced for the same grid edge by neighbouring cells
    are merged so the result is an indexed mesh.

    Args:
        volume: (D, H, W) tensor; converted to float32 internally. A corner counts
            as inside when ``volume < level``.
        level: isosurface threshold.

    Returns:
        (vertices, faces): numpy arrays. ``vertices`` are float32 (N, 3) in
        array-index coordinates (axis0, axis1, axis2); ``faces`` are int64 (M, 3)
        indices into ``vertices``. Both are empty ``(0, 3)`` arrays when the volume
        has fewer than two samples along any axis or no cell crosses the surface.

    Raises:
        AssertionError: if ``volume`` is not 3-dimensional.
    """
    assert volume.ndim == 3, "volume must be (D, H, W)"
    device = volume.device
    vol = volume.float()

    tri_table = torch.tensor(_TRI_TABLE, dtype=torch.long, device=device)        # (256, 16)
    edge_corners = torch.tensor(_EDGE_CORNERS, dtype=torch.long, device=device)  # (12, 2)
    corners = torch.tensor(_CORNERS, dtype=torch.float32, device=device)         # (8, 3)

    # Orient every edge from its lower corner to its higher corner. In the canonical
    # numbering edges 2, 3, 6 and 7 run "backwards", so the same grid edge is
    # traversed a->b by one cell and b->a by its neighbour. Interpolating from
    # opposite ends gives results that differ by a rounding error (in float32,
    # 1 - 2/3 != 1/3), which would defeat the exact-match vertex welding below.
    edge_dir = corners[edge_corners[:, 1]] - corners[edge_corners[:, 0]]  # (12, 3), one +/-1 entry
    reversed_edge = edge_dir.sum(dim=1) < 0
    edge_corners = torch.where(reversed_edge.unsqueeze(1), edge_corners.flip(1), edge_corners)

    # Corner scalar values for every cell, shape (nc0, nc1, nc2, 8).
    nc0, nc1, nc2 = vol.shape[0] - 1, vol.shape[1] - 1, vol.shape[2] - 1
    if nc0 <= 0 or nc1 <= 0 or nc2 <= 0:
        return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.int64))

    corner_vals = torch.empty((nc0, nc1, nc2, 8), dtype=torch.float32, device=device)
    for k in range(8):
        o0, o1, o2 = _CORNERS[k]
        corner_vals[..., k] = vol[o0:o0 + nc0, o1:o1 + nc1, o2:o2 + nc2]

    # Cube configuration index: bit k set when corner k is inside (val < level).
    inside = (corner_vals < level)
    bits = torch.tensor([1 << k for k in range(8)], dtype=torch.long, device=device)
    cube_index = (inside.long() * bits).sum(dim=-1)  # (nc0, nc1, nc2)

    # Cells that actually intersect the surface (neither fully outside nor fully inside).
    active = (cube_index > 0) & (cube_index < 255)
    if not active.any():
        return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.int64))

    idx0, idx1, idx2 = torch.where(active)                          # (Nactive,)
    cidx = cube_index[idx0, idx1, idx2]                             # (Nactive,)
    cell_origin = torch.stack([idx0, idx1, idx2], dim=1).float()    # (Nactive, 3)
    cell_vals = corner_vals[idx0, idx1, idx2]                       # (Nactive, 8)

    tris = tri_table[cidx]                                          # (Nactive, 16)

    # Each row holds up to 5 triangles (15 edge entries). Expand to (Nactive, 5, 3).
    tri_edges = tris[:, :15].reshape(-1, 5, 3)                      # edge indices, -1 = unused
    valid_tri = tri_edges[..., 0] >= 0                              # (Nactive, 5)

    cell_idx = torch.arange(cell_origin.shape[0], device=device).unsqueeze(1).expand(-1, 5)
    cell_idx = cell_idx[valid_tri]                                  # (T,)
    edges = tri_edges[valid_tri]                                    # (T, 3) edge index per triangle corner

    # Interpolate a vertex on each referenced edge.
    e_flat = edges.reshape(-1)                                       # (T*3,)
    cell_for_vert = cell_idx.unsqueeze(1).expand(-1, 3).reshape(-1)  # (T*3,)

    ca = edge_corners[e_flat, 0]                    # (T*3,) lower corner of the edge
    cb = edge_corners[e_flat, 1]                    # (T*3,) higher corner of the edge
    va = cell_vals[cell_for_vert, ca]               # scalar at corner a
    vb = cell_vals[cell_for_vert, cb]
    pa = cell_origin[cell_for_vert] + corners[ca]   # position of corner a (index space)
    pb = cell_origin[cell_for_vert] + corners[cb]

    denom = (vb - va)
    # Guard against a (near-)zero value difference; the vertex then sits on corner a.
    t = torch.where(denom.abs() > 1e-12, (level - va) / denom, torch.zeros_like(denom))
    t = t.clamp(0.0, 1.0).unsqueeze(1)
    verts = pa + t * (pb - pa)                      # (T*3, 3) one vertex per triangle corner

    # Weld shared vertices. Because every edge is now interpolated from the same end,
    # a grid edge shared by adjacent cells sees identical endpoint positions and values
    # and therefore yields bit-identical coordinates, so exact row deduplication gives
    # one vertex per crossed grid edge.
    uniq, inverse = torch.unique(verts, dim=0, return_inverse=True)
    faces = inverse.reshape(-1, 3)

    return (uniq.cpu().numpy().astype(np.float32), faces.cpu().numpy().astype(np.int64))
|
||||
364
comfy/ldm/cube/vae.py
Normal file
364
comfy/ldm/cube/vae.py
Normal file
@ -0,0 +1,364 @@
|
||||
"""
|
||||
Native port of Roblox/cube's shape tokenizer decode path (OneDAutoEncoder).
|
||||
|
||||
Reference: https://github.com/Roblox/cube (cube3d/model/autoencoder/*).
|
||||
|
||||
Only the DECODE path is ported (token IDs -> latents -> occupancy grid -> mesh);
|
||||
the point-cloud encoder is not needed for text-to-3D generation. Encoder weights in
|
||||
the checkpoint are loaded with strict=False and ignored.
|
||||
|
||||
Module/parameter names mirror upstream so the checkpoint loads directly:
|
||||
embedder.weight
|
||||
bottleneck.block.{codebook, cb_weight, cb_bias, c_in, c_x, c_out, ...}
|
||||
decoder.{positional_encodings, blocks.N...}
|
||||
occupancy_decoder.{query_in, attn_out, ln_f, c_head}
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Norms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CubeLayerNorm(nn.Module):
    """Layer normalisation over the last dimension, evaluated in float32.

    The input is upcast to float32 for the normalisation and the result is cast
    back to the input's dtype. Without ``elementwise_affine`` (the default) the
    module holds no parameters and ``weight`` / ``bias`` are ``None``.
    """

    def __init__(self, dim, eps=1e-6, elementwise_affine=False, dtype=None, device=None):
        super().__init__()
        self.dim = (dim,)
        self.eps = eps
        if not elementwise_affine:
            self.weight = None
            self.bias = None
            return
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=dtype, device=device))

    def forward(self, x):
        # Affine terms, when present, are upcast as well so all math runs in fp32.
        scale = None if self.weight is None else self.weight.float()
        shift = None if self.bias is None else self.bias.float()
        normed = F.layer_norm(x.float(), self.dim, scale, shift, self.eps)
        return normed.type_as(x)
|
||||
|
||||
|
||||
class CubeRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension, evaluated in float32.

    With ``elementwise_affine`` the scale is a learnable parameter named ``weight``;
    otherwise ``weight`` is a non-persistent buffer of ones, so it never appears in
    the state dict.
    """

    def __init__(self, dim, eps=1e-5, elementwise_affine=True, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        if not elementwise_affine:
            self.register_buffer("weight", torch.ones(dim), persistent=False)
            return
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))

    def forward(self, x):
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        # Scale in fp32, then return in the caller's dtype.
        return (x32 * inv_rms * self.weight.float()).type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fourier embedder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PhaseModulatedFourierEmbedder(nn.Module):
    """Fourier feature embedding that adds a fixed carrier to learned frequencies.

    Each input coordinate is multiplied by a learned per-coordinate frequency row
    (``weight``) and, separately, offset by a fixed ``carrier`` phase. The output
    concatenates the raw input with the summed cosines and the summed sines, so its
    last dimension is ``input_dim * (2 * num_freqs + 1)`` (exposed as ``out_dim``).
    ``weight`` is allocated uninitialised here.
    """

    def __init__(self, num_freqs, input_dim=3, dtype=None, device=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(input_dim, num_freqs, dtype=dtype, device=device))
        decay = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
        ramp = torch.linspace(0, 1, num_freqs)
        carrier = (decay + ramp) * 2 * math.pi
        # Derived constant: kept out of the state dict.
        self.register_buffer("carrier", carrier, persistent=False)
        self.out_dim = input_dim * (num_freqs * 2 + 1)

    def forward(self, x):
        lead_shape = x.shape[:-1]
        coords = x.float().unsqueeze(-1)  # trailing axis broadcasts against the frequencies
        freq_phase = (coords * self.weight.float()).view(*lead_shape, -1)
        carrier_phase = (coords * 0.5 * math.pi + self.carrier.float()).view(*lead_shape, -1)
        cos_part = freq_phase.cos() + carrier_phase.cos()
        sin_part = freq_phase.sin() + carrier_phase.sin()
        return torch.cat([x, cos_part, sin_part], dim=-1).type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attention building blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLP(nn.Module):
    """Two-layer feed-forward block: ``up_proj`` -> exact GELU -> ``down_proj``."""

    def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None):
        super().__init__()
        linear_kwargs = {"bias": bias, "dtype": dtype, "device": device}
        self.up_proj = ops.Linear(embed_dim, hidden_dim, **linear_kwargs)
        self.down_proj = ops.Linear(hidden_dim, embed_dim, **linear_kwargs)
        self.act_fn = nn.GELU(approximate="none")

    def forward(self, x):
        hidden = self.up_proj(x)
        hidden = self.act_fn(hidden)
        return self.down_proj(hidden)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalised queries and keys.

    Queries and keys come from one fused projection (``c_qk``), values from ``c_v``.
    After splitting into heads, q and k are each passed through a ``CubeRMSNorm``
    over the head dimension. ``eps`` is accepted but not used by this class.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        per_head = embed_dim // num_heads
        factory = {"dtype": dtype, "device": device}
        self.c_qk = ops.Linear(embed_dim, 2 * embed_dim, bias=bias, **factory)
        self.c_v = ops.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.q_norm = CubeRMSNorm(per_head, **factory)
        self.k_norm = CubeRMSNorm(per_head, **factory)

    def forward(self, x, attn_mask=None, is_causal=False):
        batch, length, width = x.shape

        def to_heads(t):
            # (batch, length, width) -> (batch, heads, length, head_dim)
            return t.view(batch, length, self.num_heads, -1).transpose(1, 2)

        queries, keys = self.c_qk(x).chunk(2, dim=-1)
        values = self.c_v(x)
        queries = self.q_norm(to_heads(queries))
        keys = self.k_norm(to_heads(keys))
        values = to_heads(values)

        # An explicit mask takes precedence; the causal flag only applies without one.
        causal = is_causal and attn_mask is None
        attended = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask=attn_mask, dropout_p=0.0, is_causal=causal
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, length, width)
        return self.c_proj(merged)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys and values from ``c``.

    ``q_dim`` and ``kv_dim`` are the input widths of the query and key/value
    projections; each falls back to ``embed_dim`` when not given.
    """

    def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, dtype=None, device=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        q_dim = q_dim or embed_dim
        kv_dim = kv_dim or embed_dim
        factory = {"dtype": dtype, "device": device}
        self.c_q = ops.Linear(q_dim, embed_dim, bias=bias, **factory)
        self.c_k = ops.Linear(kv_dim, embed_dim, bias=bias, **factory)
        self.c_v = ops.Linear(kv_dim, embed_dim, bias=bias, **factory)
        self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.num_heads = num_heads

    def forward(self, x, c, attn_mask=None):
        queries = self.c_q(x)
        keys = self.c_k(c)
        values = self.c_v(c)
        batch, q_len, width = queries.shape
        kv_len = keys.shape[1]

        def to_heads(t, length):
            # (batch, length, width) -> (batch, heads, length, head_dim)
            return t.view(batch, length, self.num_heads, -1).transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            to_heads(queries, q_len),
            to_heads(keys, kv_len),
            to_heads(values, kv_len),
            attn_mask=attn_mask,
            dropout_p=0.0,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, q_len, width)
        return self.c_proj(merged)
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each added residually.

    Both layer norms are parameter-free ``CubeLayerNorm`` instances; the MLP
    hidden width is four times ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
        self.attn = SelfAttention(embed_dim, num_heads, bias=bias, eps=eps, **factory)
        self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, **factory)

    def forward(self, x, attn_mask=None, is_causal=False):
        attn_out = self.attn(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
|
||||
|
||||
|
||||
class EncoderCrossAttentionLayer(nn.Module):
    """Pre-norm cross-attention block followed by an MLP, each added residually.

    ``x`` (queries) and ``c`` (context) are normalised separately by ``ln_1`` and
    ``ln_2`` before attention; ``ln_f`` normalises the input of the MLP.
    """

    def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, eps=1e-6, dtype=None, device=None):
        super().__init__()
        q_dim = q_dim or embed_dim
        kv_dim = kv_dim or embed_dim
        factory = {"dtype": dtype, "device": device}
        self.attn = CrossAttention(embed_dim, num_heads, q_dim=q_dim, kv_dim=kv_dim, bias=bias, **factory)
        self.ln_1 = CubeLayerNorm(q_dim, eps=eps)
        self.ln_2 = CubeLayerNorm(kv_dim, eps=eps)
        self.ln_f = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, **factory)

    def forward(self, x, c, attn_mask=None):
        normed_query = self.ln_1(x)
        normed_context = self.ln_2(c)
        x = x + self.attn(normed_query, normed_context, attn_mask=attn_mask)
        mlp_out = self.mlp(self.ln_f(x))
        return x + mlp_out
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
    """Two linear layers with a SiLU in between, mapping ``in_dim`` to ``embed_dim``."""

    def __init__(self, in_dim, embed_dim, bias=True, dtype=None, device=None):
        super().__init__()
        linear_kwargs = {"bias": bias, "dtype": dtype, "device": device}
        self.in_layer = ops.Linear(in_dim, embed_dim, **linear_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = ops.Linear(embed_dim, embed_dim, **linear_kwargs)

    def forward(self, x):
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Spherical VQ (decode-only parts)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SphericalVectorQuantizer(nn.Module):
    """Codebook lookup side of the vector quantizer (decode-only).

    The effective codebook is the raw embedding table passed through a learned
    per-channel affine (``cb_weight`` / ``cb_bias``) and then a non-affine RMS norm.
    When ``width`` differs from ``embed_dim`` linear adapters (``c_in``, ``c_x``,
    ``c_out``) are created; otherwise all three are one shared ``nn.Identity``.
    """

    def __init__(self, embed_dim, num_codes, width=None, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.num_codes = num_codes
        self.codebook = ops.Embedding(num_codes, embed_dim, **factory)

        width = width or embed_dim
        if width == embed_dim:
            passthrough = nn.Identity()
            self.c_in = passthrough
            self.c_out = passthrough
            self.c_x = passthrough
        else:
            self.c_in = ops.Linear(width, embed_dim, **factory)
            self.c_x = ops.Linear(width, embed_dim, **factory)
            self.c_out = ops.Linear(embed_dim, width, **factory)

        self.norm = CubeRMSNorm(embed_dim, elementwise_affine=False, **factory)
        # "kl" codebook regularization (released config)
        self.cb_weight = nn.Parameter(torch.ones([embed_dim], **factory))
        self.cb_bias = nn.Parameter(torch.zeros([embed_dim], **factory))

    def cb_norm(self, x):
        """Apply the learned per-channel scale and shift to ``x``."""
        scaled = x * self.cb_weight
        return scaled + self.cb_bias

    def get_codebook(self):
        """Return the codebook after the affine transform and RMS normalisation."""
        affine_codes = self.cb_norm(self.codebook.weight)
        return self.norm(affine_codes)

    def lookup_codebook(self, q):
        """Map token ids ``q`` to their code vectors, projected through ``c_out``."""
        codes = F.embedding(q, self.get_codebook())
        return self.c_out(codes)
|
||||
|
||||
|
||||
class OneDBottleNeck(nn.Module):
    """Thin container that registers ``block`` as a submodule.

    It defines no ``forward``; wrapping the module under the attribute ``block``
    gives its parameters the ``block.`` prefix in the state dict.
    """

    def __init__(self, block):
        super().__init__()
        self.block = block
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OneDDecoder(nn.Module):
    """Adds learned positional encodings to latents, then runs ``EncoderLayer`` blocks."""

    def __init__(self, num_latents, width, num_heads, num_layers, eps=1e-6, dtype=None, device=None):
        super().__init__()
        # Zero-row, non-persistent buffer; nothing inside this class reads it.
        self.register_buffer("query", torch.empty([0, width]), persistent=False)

        position_table = torch.empty(num_latents, width, dtype=dtype, device=device)
        self.positional_encodings = nn.Parameter(position_table)

        layers = nn.ModuleList()
        for _ in range(num_layers):
            layers.append(EncoderLayer(width, num_heads, eps=eps, dtype=dtype, device=device))
        self.blocks = layers

    def forward(self, z):
        """Return ``z`` plus positional encodings, passed through every block in order."""
        seq_len = z.shape[1]
        # Only the first ``seq_len`` rows are used, cast to the input's dtype.
        positions = self.positional_encodings[:seq_len].unsqueeze(0).to(z.dtype)
        hidden = z + positions
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden
|
||||
|
||||
|
||||
class OneDOccupancyDecoder(nn.Module):
    """Predicts per-query outputs by cross-attending embedded query points to latents."""

    def __init__(self, embedder, out_features, width, num_heads, eps=1e-6, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.embedder = embedder
        # Lifts the embedder's output (``embedder.out_dim`` features) to ``width``.
        self.query_in = MLPEmbedder(embedder.out_dim, width, **factory)
        self.attn_out = EncoderCrossAttentionLayer(width, num_heads, **factory)
        self.ln_f = CubeLayerNorm(width, eps=eps, elementwise_affine=True, **factory)
        self.c_head = ops.Linear(width, out_features, **factory)

    def forward(self, queries, latents):
        """Embed ``queries``, attend to ``latents``, then normalize and project."""
        embedded = self.embedder(queries)
        hidden = self.query_in(embedded)
        hidden = self.attn_out(hidden, latents)
        normalized = self.ln_f(hidden)
        return self.c_head(normalized)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level shape VAE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij"):
    """Build a dense, axis-aligned grid of 3D sample points inside a bounding box.

    Args:
        bbox_min: Lower corner (x, y, z). Any array-like of length 3
            (ndarray, tuple or list).
        bbox_max: Upper corner (x, y, z). Any array-like of length 3.
        resolution_base: log2 of the number of cells per axis; each axis gets
            ``int(2 ** resolution_base) + 1`` sample points.
        indexing: Forwarded to ``np.meshgrid`` ("ij" or "xy").

    Returns:
        A tuple ``(xyz, grid_size, length)``:
        ``xyz`` is a float32 array of shape ``(points_per_axis ** 3, 3)``,
        ``grid_size`` is ``[points_per_axis] * 3``, and ``length`` is
        ``bbox_max - bbox_min`` as an ndarray.
    """
    # Coerce so plain tuples/lists work too; ndarrays pass through unchanged.
    bbox_min = np.asarray(bbox_min)
    bbox_max = np.asarray(bbox_max)
    length = bbox_max - bbox_min

    # Cells per axis is 2**resolution_base (truncated); points are cells + 1
    # so both bounding-box faces are sampled.
    points_per_axis = int(np.exp2(resolution_base)) + 1

    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], points_per_axis, dtype=np.float32)
        for axis in range(3)
    ]
    xs, ys, zs = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1).reshape(-1, 3)
    grid_size = [points_per_axis] * 3
    return xyz, grid_size, length
|
||||
|
||||
|
||||
class CubeShapeVAE(nn.Module):
    """Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored)."""

    # Fixed query bounds for the occupancy grid (upstream default).
    decode_bounds = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05)

    def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12,
                 num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6,
                 dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.cfg_num_encoder_latents = num_encoder_latents
        self.cfg_num_codes = num_codes

        self.embedder = PhaseModulatedFourierEmbedder(num_freqs=num_freqs, input_dim=3, **factory)
        quantizer = SphericalVectorQuantizer(embed_dim, num_codes, width, **factory)
        self.bottleneck = OneDBottleNeck(quantizer)
        self.decoder = OneDDecoder(
            num_encoder_latents, width, num_heads, num_decoder_layers, eps=eps, **factory)
        # The occupancy decoder shares this module's embedder instance.
        self.occupancy_decoder = OneDOccupancyDecoder(
            self.embedder, out_dim, width, num_heads, eps=eps, **factory)

    @torch.no_grad()
    def decode(self, samples, resolution_base=8.0, chunk_size=100_000, **kwargs):
        """Token IDs -> occupancy grid logits. Entry point for comfy.sd.VAE.decode, which
        manages model loading/device/dtype. `samples` arrive as (B, 1, num_tokens) in the
        VAE working dtype on the load device. VAE.decode applies a trailing movedim(1, -1),
        so pre-invert it here to hand the node grid logits as (B, gx, gy, gz)."""
        batch = samples.shape[0]
        # Flatten per sample and keep at most cfg_num_encoder_latents tokens.
        token_ids = samples.reshape(batch, -1)[:, :self.cfg_num_encoder_latents]
        # Round to integer IDs and clamp into the valid codebook range.
        token_ids = token_ids.round().long().clamp(0, self.cfg_num_codes - 1)

        latents = self.decode_indices(token_ids)
        geometry = self.extract_geometry(
            latents,
            bounds=self.decode_bounds,
            resolution_base=resolution_base,
            chunk_size=chunk_size,
        )
        grid_logits = geometry[0]
        return grid_logits.movedim(-1, 1)

    @torch.no_grad()
    def decode_indices(self, shape_ids):
        """Look up ``shape_ids`` in the VQ codebook and run the decoder on the result."""
        quantized = self.bottleneck.block.lookup_codebook(shape_ids)
        return self.decoder(quantized)

    @torch.no_grad()
    def query(self, queries, latents):
        """Run the occupancy decoder and drop its trailing size-1 dimension (if any)."""
        logits = self.occupancy_decoder(queries, latents)
        return logits.squeeze(-1)

    @torch.no_grad()
    def extract_geometry(self, latents, bounds=(-1.05, -1.05, -1.05, 1.05, 1.05, 1.05),
                         resolution_base=8.0, chunk_size=100_000):
        """Evaluate occupancy logits on a dense grid spanning ``bounds``.

        Returns ``(grid_logits, grid_size, bbox_size, bbox_min)`` where
        ``grid_logits`` is a float tensor viewed as
        ``(batch, grid_size[0], grid_size[1], grid_size[2])``.
        """
        lower = np.array(bounds[0:3])
        upper = np.array(bounds[3:6])
        bbox_size = upper - lower

        points, grid_size, _ = generate_dense_grid_points(lower, upper, resolution_base, indexing="ij")
        points = torch.from_numpy(points)
        batch_size = latents.shape[0]

        per_chunk_logits = []
        for offset in range(0, points.shape[0], chunk_size):
            chunk = points[offset:offset + chunk_size, :]
            valid = chunk.shape[0]
            # A short trailing chunk is zero-padded up to chunk_size; the
            # padded rows are sliced off again after the query.
            short_tail = offset > 0 and valid < chunk_size
            if short_tail:
                chunk = F.pad(chunk, [0, 0, 0, chunk_size - valid])
            # Broadcast over the batch and match latents' device and dtype.
            batched = chunk.unsqueeze(0).expand(batch_size, -1, -1).to(latents)
            logits = self.query(batched, latents)
            per_chunk_logits.append(logits[:, :valid])

        flat_logits = torch.cat(per_chunk_logits, dim=1).detach()
        grid_logits = flat_logits.view(
            batch_size, grid_size[0], grid_size[1], grid_size[2]).float()
        return grid_logits, grid_size, bbox_size, lower
|
||||
|
||||
|
||||
def grid_logits_to_mesh(grid_logit, grid_size, bbox_size, bbox_min, level=0.0):
    """Occupancy-logit grid -> mesh, using the vendored dependency-free marching cubes
    (classic Lorensen, same family as upstream cube's default warp backend). Vertices are
    rescaled from grid-index space into the bbox, matching upstream's transform.

    Returns ``(vertices, faces)``: vertices as a float32 array and faces as a
    C-contiguous array.
    """
    from comfy.ldm.cube.marching_cubes import marching_cubes

    index_vertices, faces = marching_cubes(grid_logit, level)

    # Same operation order as the one-line form (divide, scale, shift) so the
    # floating-point result is unchanged.
    unit_vertices = index_vertices / np.array(grid_size)
    world_vertices = unit_vertices * bbox_size + bbox_min

    # The vendored Lorensen table already emits outward-facing winding for this
    # occupancy convention, so (unlike the upstream skimage path) no face flip is needed.
    return world_vertices.astype(np.float32), np.ascontiguousarray(faces)
|
||||
@ -44,6 +44,7 @@ import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.cube.gpt
|
||||
import comfy.ldm.wan.model_wandancer
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.triposplat.model
|
||||
@ -1903,6 +1904,26 @@ class Hunyuan3Dv2(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class Cube3D(BaseModel):
    """Roblox Cube3D shape GPT (autoregressive). Generation goes through the
    dedicated `cube` sampler (SamplerCustomAdvanced), never KSampler/apply_model."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.cube.gpt.DualStreamRoformer,
        )

    def _apply_model(self, *args, **kwargs):
        # Always raises: this model is not meant to be driven through apply_model.
        raise RuntimeError(
            "Cube3D is an autoregressive token model. Use the 'cube' sampler "
            "(SamplerCube + SamplerCustomAdvanced), not KSampler."
        )

    def extra_conds(self, **kwargs):
        """Return a dict with ``c_crossattn`` when ``cross_attn`` is provided, else empty."""
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is None:
            return {}
        return {'c_crossattn': comfy.conds.CONDRegular(cross_attn)}
|
||||
|
||||
|
||||
class Hunyuan3Dv2_1(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
|
||||
|
||||
@ -654,6 +654,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}shape_proj.weight'.format(key_prefix) in state_dict_keys and '{}lm_head.weight'.format(key_prefix) in state_dict_keys: # Roblox Cube3D shape GPT
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cube3d"
|
||||
n_embd = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["n_embd"] = n_embd
|
||||
dit_config["shape_model_vocab_size"] = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[0] - 3
|
||||
dit_config["n_layer"] = count_blocks(state_dict_keys, '{}transformer.dual_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["n_single_layer"] = count_blocks(state_dict_keys, '{}transformer.single_blocks.'.format(key_prefix) + '{}.')
|
||||
head_dim = state_dict['{}transformer.dual_blocks.0.attn.pre_x.q_norm.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["n_head"] = n_embd // head_dim
|
||||
dit_config["shape_model_embed_dim"] = state_dict['{}shape_proj.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["text_model_embed_dim"] = state_dict['{}text_proj.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["use_bbox"] = '{}bbox_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["bias"] = '{}text_proj.bias'.format(key_prefix) in state_dict_keys
|
||||
dit_config["rope_theta"] = 10000 # not stored in the state dict; upstream's fixed constant
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
|
||||
dit_config = {}
|
||||
|
||||
34
comfy/sd.py
34
comfy/sd.py
@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.cube.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
@ -777,6 +778,39 @@ class VAE:
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
# Roblox Cube3D shape tokenizer (OneDAutoEncoder, decode-only)
|
||||
elif "bottleneck.block.codebook.weight" in sd:
|
||||
self.latent_dim = 1
|
||||
# The VQ bottleneck (get_codebook/lookup_codebook) reads raw parameters
|
||||
# outside any hooked forward, so the streaming-offload cast hooks can't
|
||||
# relocate them; the model must be fully resident to decode. This is a
|
||||
# correctness requirement, declared via the standard flag (like the audio
|
||||
# VAEs) rather than managed manually in the node.
|
||||
self.disable_offload = True
|
||||
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
|
||||
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
|
||||
width = sd["bottleneck.block.c_out.weight"].shape[0]
|
||||
num_encoder_latents = sd["decoder.positional_encodings"].shape[0]
|
||||
head_dim = sd["decoder.blocks.0.attn.q_norm.weight"].shape[0]
|
||||
num_heads = width // head_dim
|
||||
num_freqs = sd["embedder.weight"].shape[1]
|
||||
num_decoder_layers = len({k.split(".")[2] for k in sd if k.startswith("decoder.blocks.")})
|
||||
self.first_stage_model = comfy.ldm.cube.vae.CubeShapeVAE(
|
||||
num_encoder_latents=num_encoder_latents, embed_dim=embed_dim, width=width,
|
||||
num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers,
|
||||
num_codes=num_codes,
|
||||
)
|
||||
# Decode goes through the managed comfy.sd.VAE.decode path; the grid logits
|
||||
# are float32 regardless of weight dtype, so keep process_output identity
|
||||
# (the default clamps to [0, 1] in-place and would destroy the isosurface).
|
||||
self.process_output = lambda image: image
|
||||
self.process_input = lambda image: image
|
||||
# shape is the token-ID latent (B, 1, num_tokens); size by num_tokens.
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[-1] * 768) * model_management.dtype_size(dtype)
|
||||
# fp32-only (unlike most VAEs that allow fp16/bf16): the VQ codebook lookup
|
||||
# and occupancy-grid query must run in fp32 to match upstream and keep the
|
||||
# isosurface stable.
|
||||
self.working_dtypes = [torch.float32]
|
||||
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
|
||||
|
||||
@ -1550,6 +1550,32 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
|
||||
class Cube3D(supported_models_base.BASE):
    """Supported-model entry for configs detected with ``image_model == "cube3d"``."""

    unet_config = {"image_model": "cube3d"}
    unet_extra_config = {}
    sampling_settings = {}

    latent_format = latent_formats.Cube3D
    memory_usage_factor = 1.0

    # Upstream keeps fp32 weights and uses bf16 autocast during the forward pass
    # (see sample_cube). Prefer fp32 weights for parity; bf16 is the low-VRAM fallback.
    supported_inference_dtypes = [torch.float32, torch.bfloat16]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.Cube3D`` wrapper for this config."""
        model = model_base.Cube3D(self, device=device)
        return model

    def clip_target(self, state_dict={}):
        """Return None: no text encoder is built from this checkpoint."""
        # No bundled text encoder: the cube checkpoint is GPT-only. The graph wires a
        # standard CLIPLoader(clip-l)/CLIPTextEncode, so there is no clip_target to build.
        return None
|
||||
|
||||
class TripoSplat(supported_models_base.BASE):
|
||||
# Image -> 3D gaussian splat flow denoiser
|
||||
unet_config = {
|
||||
@ -2292,6 +2318,7 @@ models = [
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
Cube3D,
|
||||
TripoSplat,
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
|
||||
156
comfy_extras/nodes_cube.py
Normal file
156
comfy_extras/nodes_cube.py
Normal file
@ -0,0 +1,156 @@
|
||||
"""
|
||||
Nodes for native Roblox Cube3D text-to-3D support.
|
||||
|
||||
Graph:
|
||||
CLIPLoader(clip-l) -> CLIPTextEncode -> CONDITIONING
|
||||
UNETLoader(shape_gpt) -> MODEL --\
|
||||
VAELoader(shape_tokenizer) -> VAE -> CubeCodebookPatch -> MODEL
|
||||
CFGGuider(MODEL, pos, neg, cfg) + SamplerCube + (trivial sigmas) + EmptyCubeLatent
|
||||
-> SamplerCustomAdvanced -> LATENT (token IDs)
|
||||
VAEDecodeCube(VAE, LATENT) -> MESH -> SaveGLB
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.ldm.cube.vae
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy_extras.nodes_save_3d import pack_variable_mesh_batch
|
||||
|
||||
|
||||
class EmptyCubeLatent(IO.ComfyNode):
    """Node that emits an all-zero (batch_size, 1, num_tokens) latent tagged "cube_tokens"."""

    @classmethod
    def define_schema(cls):
        num_tokens_input = IO.Int.Input(
            "num_tokens", default=1024, min=1, max=8192,
            tooltip="Shape token sequence length. Must match the tokenizer "
                    "(1024 for cube3d-v0.5, 512 for v0.1).",
        )
        batch_size_input = IO.Int.Input("batch_size", default=1, min=1, max=64)
        return IO.Schema(
            node_id="EmptyCubeLatent",
            category="latent/3d",
            inputs=[num_tokens_input, batch_size_input],
            outputs=[IO.Latent.Output()],
        )

    @classmethod
    def execute(cls, num_tokens, batch_size) -> IO.NodeOutput:
        # Channels-first 1D latent (B, 1, num_tokens), mirroring Hunyuan3Dv2's (B, C, L)
        # convention (latent_channels=1). The sampler only uses the sequence length.
        target_device = comfy.model_management.intermediate_device()
        placeholder = torch.zeros([batch_size, 1, num_tokens], device=target_device)
        return IO.NodeOutput({"samples": placeholder, "type": "cube_tokens"})
|
||||
|
||||
|
||||
class CubeCodebookPatch(IO.ComfyNode):
    """Inject the projected VQ codebook into the GPT token-embedding table.

    Upstream copies shape_proj(tokenizer.codebook) into wte.weight[:num_codes] at load
    time; without it generation is garbage. Done here as a ModelPatcher object patch so
    it composes with normal model loading/offload."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            IO.Model.Input("model"),
            IO.Vae.Input("vae"),
        ]
        return IO.Schema(
            node_id="CubeCodebookPatch",
            display_name="Cube Codebook Patch",
            category="advanced/model",
            inputs=node_inputs,
            outputs=[IO.Model.Output()],
        )

    @classmethod
    def execute(cls, model, vae) -> IO.NodeOutput:
        wte_key = "diffusion_model.transformer.wte.weight"

        gpt = model.get_model_object("diffusion_model")
        quantizer = vae.first_stage_model.bottleneck.block
        codebook = quantizer.get_codebook()  # (num_codes, embed_dim) fp32

        # Project on whatever device/dtype shape_proj's weight currently has.
        proj_weight = gpt.shape_proj.weight
        projected = gpt.shape_proj(
            codebook.to(device=proj_weight.device, dtype=proj_weight.dtype))  # (num_codes, n_embd)

        # Work on a copy of the embedding table; only the leading rows are
        # overwritten, the remaining rows keep their values.
        patched_table = model.get_model_object(wte_key).clone()
        row_count = projected.shape[0]
        patched_table[:row_count] = projected.to(
            device=patched_table.device, dtype=patched_table.dtype)

        patched_model = model.clone()
        patched_model.add_object_patch(
            wte_key, torch.nn.Parameter(patched_table, requires_grad=False))
        return IO.NodeOutput(patched_model)
|
||||
|
||||
|
||||
class SamplerCube(IO.ComfyNode):
    """Node that outputs the "cube" ksampler configured with a ``top_p`` option."""

    @classmethod
    def define_schema(cls):
        top_p_input = IO.Float.Input(
            "top_p", default=1.0, min=0.0, max=1.0, step=0.01,
            tooltip="1.0 = deterministic greedy (upstream default). "
                    "<1.0 enables nucleus sampling.",
        )
        return IO.Schema(
            node_id="SamplerCube",
            display_name="Sampler Cube (autoregressive)",
            category="sampling/custom_sampling/samplers",
            inputs=[top_p_input],
            outputs=[IO.Sampler.Output()],
        )

    @classmethod
    def execute(cls, top_p) -> IO.NodeOutput:
        extra_options = {"top_p": top_p}
        sampler = comfy.samplers.ksampler("cube", extra_options)
        return IO.NodeOutput(sampler)
|
||||
|
||||
|
||||
class VAEDecodeCube(IO.ComfyNode):
    """Node that decodes cube token latents to a grid and meshes each batch item."""

    @classmethod
    def define_schema(cls):
        resolution_input = IO.Float.Input(
            "resolution_base", default=8.0, min=4.0, max=10.0, step=0.5,
            tooltip="Grid cells per axis = 2^resolution_base. 8.0 matches "
                    "upstream default (257^3 grid).",
        )
        chunk_input = IO.Int.Input(
            "chunk_size", default=100000, min=1000, max=2000000, advanced=True)
        return IO.Schema(
            node_id="VAEDecodeCube",
            display_name="VAE Decode Cube (3D)",
            category="latent/3d",
            inputs=[
                IO.Vae.Input("vae"),
                IO.Latent.Input("samples"),
                resolution_input,
                chunk_input,
            ],
            outputs=[IO.Mesh.Output()],
        )

    @classmethod
    def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
        # Managed decode: comfy.sd.VAE.decode handles model loading + device/dtype and
        # returns the occupancy grid logits (B, gx, gy, gz). Marching cubes runs here.
        decode_options = {"resolution_base": resolution_base, "chunk_size": chunk_size}
        grid = vae.decode(samples["samples"], vae_options=decode_options)

        # Rebuild the bounding box from the model's fixed decode bounds.
        bounds = vae.first_stage_model.decode_bounds
        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        grid_size = list(grid.shape[1:])

        verts_list = []
        faces_list = []
        for item_grid in grid:
            vertices, faces = comfy.ldm.cube.vae.grid_logits_to_mesh(
                item_grid, grid_size, bbox_size, bbox_min)
            verts_list.append(torch.from_numpy(vertices))
            faces_list.append(torch.from_numpy(faces.astype(np.int64)))

        mesh = pack_variable_mesh_batch(verts_list, faces_list)
        return IO.NodeOutput(mesh)
|
||||
|
||||
|
||||
class CubeExtension(ComfyExtension):
    """Extension that registers the Cube3D nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        node_classes = [
            EmptyCubeLatent,
            CubeCodebookPatch,
            SamplerCube,
            VAEDecodeCube,
        ]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> CubeExtension:
    """Return a new ``CubeExtension`` instance."""
    extension = CubeExtension()
    return extension
|
||||
@ -317,74 +317,11 @@ class PreviewPointCloud(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
|
||||
|
||||
|
||||
class Load3DAdvanced(IO.ComfyNode):
    """Node that loads a mesh file from the input "3d" folder with viewport state."""

    @classmethod
    def define_schema(cls):
        base_path = Path(folder_paths.get_input_directory())
        input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
        # Make sure the "3d" input folder exists before scanning it.
        os.makedirs(input_dir, exist_ok=True)

        # Collect mesh files recursively, as paths relative to the input directory.
        files = []
        for file_path in Path(input_dir).rglob("*"):
            if file_path.suffix.lower() not in MESH_EXTENSIONS:
                continue
            files.append(normalize_path(str(file_path.relative_to(base_path))))

        aliases = [
            "load mesh",
            "load gltf",
            "load glb",
            "load obj",
            "load fbx",
            "load stl",
        ]
        return IO.Schema(
            node_id="Load3DAdvanced",
            display_name="Load 3D (Advanced)",
            category="3d",
            search_aliases=aliases,
            is_experimental=True,
            inputs=[
                IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model),
                IO.Load3D.Input("viewport_state"),
                IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
                IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
            ],
            outputs=[
                IO.File3DAny.Output(display_name="model_3d"),
                IO.Load3DModelInfo.Output(display_name="model_3d_info"),
                IO.Load3DCamera.Output(display_name="camera_info"),
                IO.Int.Output(display_name="width"),
                IO.Int.Output(display_name="height"),
            ],
        )

    @classmethod
    def validate_inputs(cls, model_file, **kwargs) -> bool | str:
        """Return True when no file is selected or it exists, else an error message."""
        nothing_selected = not model_file or model_file == "none"
        if nothing_selected or folder_paths.exists_annotated_filepath(model_file):
            return True
        return f"Invalid 3D model file: {model_file}"

    @classmethod
    def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
        # "none" or an empty selection yields no file object.
        if model_file and model_file != "none":
            file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
        else:
            file_3d = None
        # model_3d_info is optional; camera_info is read with a direct key lookup.
        model_3d_info = viewport_state.get('model_3d_info', [])
        camera_info = viewport_state['camera_info']
        return IO.NodeOutput(file_3d, model_3d_info, camera_info, width, height)
|
||||
|
||||
|
||||
class Load3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Load3D,
|
||||
Load3DAdvanced,
|
||||
Preview3D,
|
||||
Preview3DAdvanced,
|
||||
PreviewGaussianSplat,
|
||||
|
||||
1
nodes.py
1
nodes.py
@ -2433,6 +2433,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
"nodes_ar_video.py",
|
||||
"nodes_cube.py",
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
"nodes_glsl.py",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.45.15
|
||||
comfyui-workflow-templates==0.9.98
|
||||
comfyui-embedded-docs==0.5.4
|
||||
comfyui-embedded-docs==0.5.3
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
Reference in New Issue
Block a user