Compare commits

..

8 Commits

Author SHA1 Message Date
e7f99168ae Cube3D: document convention deviations + drop unused VAE flag (review aid)
- Remove the unused self.cube3d VAE flag (set but never read).
- Comment why VAE working_dtypes is fp32-only (VQ lookup + occupancy query
  parity), unlike most VAEs that allow fp16/bf16.
- Comment why Cube3D.clip_target() returns None (GPT-only checkpoint; graph
  wires a standard CLIPLoader/CLIPTextEncode).
- Note rope_theta=10000 is upstream's fixed constant, not in the state dict.

No behaviour change; comments/cleanup only.

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
2026-06-14 23:58:14 -07:00
029b782936 Cube3D: fix mesh winding for vendored marching cubes
The vendored Lorensen table emits the opposite base winding from skimage, so
the upstream-style faces[:, [2,1,0]] flip produced inward-facing normals
(negative mesh volume). Drop the flip so normals point outward (positive
volume), matching the upstream output orientation.

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
2026-06-14 23:48:03 -07:00
81f5f84ad6 Cube3D: vendor dependency-free marching cubes, drop scikit-image
scikit-image was added solely for Cube3D's VAEDecodeCube. Replace it with a
vendored, vectorized pure-PyTorch marching cubes (classic Lorensen tables) in
comfy/ldm/cube/marching_cubes.py. This is the same algorithm family as upstream
cube's default warp.MarchingCubes backend, so geometry is closer to upstream's
default than skimage's Lewiner fallback was.

Validated against skimage method='lorensen': identical face count and surface
(nearest-neighbour distance ~3.8e-6, float precision) on sphere/torus fields.
Vertices are welded (shared grid edges interpolate identically) for a clean
indexed mesh. requirements.txt no longer needs scikit-image.

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
2026-06-14 23:44:20 -07:00
d8635dcb39 Cube3D: keep disable_offload=True (VQ decode needs full residency)
The VQ bottleneck reads raw parameters outside any hooked forward, so the
streaming-offload cast hooks cannot relocate them and decode fails with a
device mismatch under partial load. disable_offload is the standard
declarative flag for VAEs that need full residency (audio VAEs do the same),
and the decode still flows through the managed comfy.sd.VAE.decode path.

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
2026-06-14 23:31:41 -07:00
aeb3c77ae9 Cube3D: route VAE decode through managed comfy.sd.VAE.decode
Stop fighting ComfyUI's model management. VAEDecodeCube was manually
calling load_models_gpu + .to(vae.device) and the VAE forced
disable_offload=True because it bypassed the managed decode path.

Now CubeShapeVAE.decode(samples) is the entry point that comfy.sd.VAE.decode
calls, so loading/device/dtype are handled automatically (like Hunyuan3Dv2):
- removed disable_offload=True (let the offload system manage weights)
- removed manual load_models_gpu + .to(device) from the node
- process_output set to identity (default clamps [0,1] in-place and would
  destroy the occupancy isosurface)
- decode() pre-inverts VAE.decode's trailing movedim(1,-1) so the node
  receives grid logits unchanged (parity preserved)
- memory_used_decode sized by num_tokens (shape[-1]) for the new latent layout

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
2026-06-14 23:28:22 -07:00
a6c7397b71 Cube3D: use channels-first 1D latent (B,1,L) like Hunyuan3Dv2
Replaces the dummy trailing-dim latent with a channels-first 1D latent
(B, 1, num_tokens) and a dedicated latent_formats.Cube3D
(latent_channels=1, latent_dimensions=1). This mirrors the existing
native 3D model Hunyuan3Dv2's (B, C, L) convention and avoids
fix_empty_latent_channels truncating the token sequence (it narrows
dim=1 to latent_channels for empty latents). Requires no core sampler
changes: encode_model_conds sees a valid noise.shape[2].

- latent_formats.Cube3D added; wired into supported_models.Cube3D
- EmptyCubeLatent emits (B, 1, num_tokens)
- sample_cube takes T from x.shape[-1], returns (B, 1, T), and repeats
  conditioning to the latent batch size

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
2026-06-14 23:14:17 -07:00
871f7bc390 Cube3D: fix graph integration (3D latent, VAE device, fp32 cond, scikit-image)
Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
2026-06-14 22:59:11 -07:00
01a8783bee Add native Roblox Cube3D text-to-3D support
Cube3D is an autoregressive VQ-token shape model (DualStreamRoformer) plus a
VQ-VAE shape tokenizer (OneDAutoEncoder), not a diffusion model. It is wired
natively following the Causal-WAN AR-video pattern: the GPT loads as a normal
MODEL and generation runs through a dedicated 'cube' sampler instead of KSampler.

- comfy/ldm/cube/gpt.py: DualStreamRoformer port (dual-stream RoPE attention,
  per-head RMSNorm, SwiGLU, KV cache; rope_theta=10000).
- comfy/ldm/cube/vae.py: OneDAutoEncoder decode path (codebook lookup, decoder,
  occupancy decoder, dense-grid extraction + skimage marching cubes).
- model_detection/supported_models/model_base: register shape_gpt as Cube3D MODEL
  (dims inferred from state dict; apply_model guarded to point at SamplerCube).
- sd.py: detect shape_tokenizer and build CubeShapeVAE.
- k_diffusion/sampling.py: sample_cube autoregressive sampler (decaying CFG +
  optional top-p), faithful to upstream Engine.run_gpt.
- comfy_extras/nodes_cube.py: EmptyCubeLatent, CubeCodebookPatch (inject VQ
  codebook into wte), SamplerCube, VAEDecodeCube (-> MESH).

Reuses CLIP-L conditioning, CFGGuider/SamplerCustomAdvanced, and SaveGLB.

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
2026-06-14 20:21:37 -07:00
13 changed files with 1544 additions and 64 deletions

View File

@ -1955,3 +1955,120 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
transformer_options.pop("ar_state", None)
return output
def _cube_process_logits(logits, top_p, generator):
"""Token selection. top_p>=1 or <=0 -> greedy argmax (upstream default, deterministic)."""
if top_p is None or top_p >= 1.0 or top_p <= 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
remove = sorted_logits.softmax(dim=-1).cumsum(dim=-1) > top_p
remove[..., 0] = False
idx_remove = remove.scatter(-1, sorted_idx, remove)
logits = logits.masked_fill(idx_remove, float("-inf"))
probs = torch.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1, generator=generator)
@torch.no_grad()
def sample_cube(model, x, sigmas, extra_args=None, callback=None, disable=None, top_p=1.0):
"""
Autoregressive sampler for Roblox Cube3D shape GPT (DualStreamRoformer).
Not a diffusion sampler: the noised input `x` and `sigmas` values are ignored;
only x's shape (batch, 1, num_tokens) is used. Generates a 1024-long sequence of VQ
token IDs from CLIP text conditioning, with upstream's linearly-decaying CFG and
optional top-p. Plugs into SamplerCustomAdvanced via the SamplerCube node.
Faithful to cube3d.inference.engine.Engine.run_gpt:
gamma_i = cfg * (T - i) / T ; logits = (1+gamma)*cond - gamma*uncond
fp32 weights + bf16 autocast on cuda.
"""
import comfy.model_management
extra_args = {} if extra_args is None else extra_args
guider = model.inner_model # CFGGuider
base_model = guider.inner_model # BaseModel (Cube3D)
cube = base_model.diffusion_model
cfg = getattr(guider, "cfg", 3.0)
def get_cond(name):
conds = guider.conds.get(name, None)
if not conds:
return None
return conds[0]["model_conds"]["c_crossattn"].cond
pos = get_cond("positive")
neg = get_cond("negative")
if pos is None:
raise ValueError("sample_cube requires positive conditioning (CLIP-L text embeds).")
device = x.device
weight_dtype = base_model.get_dtype()
T = x.shape[-1] # sequence length; latent is (batch, 1, num_tokens)
batch = x.shape[0]
import comfy.utils
pos = comfy.utils.repeat_to_batch_size(pos, batch)
if neg is not None:
neg = comfy.utils.repeat_to_batch_size(neg, batch)
use_cfg = (cfg is not None) and (cfg > 0.0) and (neg is not None)
autocast_enabled = (device.type == "cuda")
cache_dtype = torch.bfloat16 if autocast_enabled else weight_dtype
def add_bbox(c):
if not getattr(cube, "use_bbox", False):
return c
bbox = torch.zeros((c.shape[0], 3), device=device, dtype=c.dtype)
return torch.cat([c, cube.bbox_proj(bbox).unsqueeze(1)], dim=1)
# Conditioning (text_proj + bbox_proj) is computed in the model's weight dtype
# OUTSIDE the bf16 autocast block, matching upstream cube's Engine.prepare_inputs
# (run_clip/encode_text run in full precision). The autocast only covers the
# autoregressive transformer forward, exactly like Engine.run_gpt.
cond = add_bbox(cube.encode_text(pos.to(device=device, dtype=weight_dtype)))
if use_cfg:
ucond = add_bbox(cube.encode_text(neg.to(device=device, dtype=weight_dtype)))
cond = torch.cat([cond, ucond], dim=0)
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
bos = torch.full((cond.shape[0], 1), cube.shape_bos_id, dtype=torch.long, device=device)
embed = cube.encode_token(bos)
Bp, input_seq_len, dim = embed.shape
embed_buffer = torch.zeros((Bp, input_seq_len + T, dim), dtype=embed.dtype, device=device)
embed_buffer[:, :input_seq_len, :].copy_(embed)
kv_cache = cube.init_kv_cache(Bp, cond.shape[1], T + 1, cache_dtype, device)
num_codes = cube.vocab_size - 3
seed = extra_args.get("seed", 0)
generator = None
if device.type != "mps":
generator = torch.Generator(device=device).manual_seed(int(seed))
output_ids = []
for i in trange(T, disable=disable):
comfy.model_management.throw_exception_if_processing_interrupted()
curr_pos_id = torch.tensor([i], dtype=torch.long, device=device)
logits = cube(embed_buffer, cond, kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=(i > 0))
logits = logits[:, 0, :num_codes]
if use_cfg:
cond_logits, uncond_logits = logits.float().chunk(2, dim=0)
gamma = cfg * (T - i) / T
logits = (1.0 + gamma) * cond_logits - gamma * uncond_logits
else:
logits = logits.float()
next_id = _cube_process_logits(logits, top_p, generator)
output_ids.append(next_id)
next_embed = cube.encode_token(next_id)
if use_cfg:
next_embed = torch.cat([next_embed, next_embed], dim=0)
embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[0], "sigma_hat": sigmas[0], "denoised": x})
# (B, T) token IDs -> (B, 1, T) to keep the channels-first 1D latent layout.
return torch.cat(output_ids, dim=1).to(torch.float32).unsqueeze(1)

View File

@ -775,6 +775,16 @@ class Hunyuan3Dv2mini(LatentFormat):
latent_dimensions = 1
scale_factor = 1.0188137142395404
class Cube3D(LatentFormat):
# Roblox Cube3D shape "latent" is a flat sequence of VQ token IDs (one scalar per
# position), so it maps to a channels-first 1D latent (B, 1, num_tokens), mirroring
# Hunyuan3Dv2's (B, C, L) convention. latent_channels=1 keeps fix_empty_latent_channels
# from truncating the token sequence. scale_factor=1.0 since IDs must pass through
# process_latent_in/out unchanged.
latent_channels = 1
latent_dimensions = 1
scale_factor = 1.0
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2

417
comfy/ldm/cube/gpt.py Normal file
View File

@ -0,0 +1,417 @@
"""
Native port of Roblox/cube's shape GPT (DualStreamRoformer).
Reference: https://github.com/Roblox/cube (cube3d/model/gpt/dual_stream_roformer.py
and cube3d/model/transformers/*).
This is an autoregressive transformer over discrete VQ shape tokens, conditioned on
CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated
`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler.
The forward pass is kept faithful to upstream so token IDs match bit-for-bit:
* rope_theta = 10000
* per-head RMSNorm on Q and K
* dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only
* two separate RoPE frequency tensors (dual blocks offset cond tokens by S)
* SwiGLU MLP, non-affine LayerNorm upcast to fp32
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Norms (faithful to cube3d/model/transformers/norm.py)
# ---------------------------------------------------------------------------
class CubeLayerNorm(nn.Module):
"""Non-affine LayerNorm that upcasts to fp32 then back (matches cube)."""
def __init__(self, dim, eps=1e-6):
super().__init__()
self.dim = (dim,)
self.eps = eps
def forward(self, x):
y = F.layer_norm(x.float(), self.dim, None, None, self.eps)
return y.type_as(x)
class CubeRMSNorm(nn.Module):
"""Per-head RMSNorm with learnable weight, computed in fp32 (matches cube)."""
def __init__(self, dim, eps=1e-5, dtype=None, device=None):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
def forward(self, x):
xf = x.float()
out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
return (out * self.weight).type_as(x)
# ---------------------------------------------------------------------------
# RoPE (faithful to cube3d/model/transformers/rope.py)
# ---------------------------------------------------------------------------
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
if curr_pos_id is None:
freqs_cis = freqs_cis[:, -x.shape[2]:].unsqueeze(1)
else:
freqs_cis = freqs_cis[:, curr_pos_id, :].unsqueeze(1)
y = torch.view_as_real(x_ * freqs_cis).flatten(3)
return y.type_as(x)
def precompute_freqs_cis(dim, t, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim))
freqs = torch.outer(t.contiguous().view(-1), freqs).reshape(*t.shape, -1)
return torch.polar(torch.ones_like(freqs), freqs)
def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):
q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id)
k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None)
return F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0,
is_causal=is_causal and attn_mask is None,
)
# ---------------------------------------------------------------------------
# KV cache
# ---------------------------------------------------------------------------
class Cache:
def __init__(self, key_states, value_states):
self.key_states = key_states
self.value_states = value_states
def update(self, curr_pos_id, k, v):
self.key_states.index_copy_(2, curr_pos_id, k)
self.value_states.index_copy_(2, curr_pos_id, v)
# ---------------------------------------------------------------------------
# Shared building blocks
# ---------------------------------------------------------------------------
class SwiGLUMLP(nn.Module):
def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.gate_proj = operations.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
self.up_proj = operations.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
self.down_proj = operations.Linear(hidden_dim, embed_dim, bias=bias, dtype=dtype, device=device)
def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class SelfAttentionWithRotaryEmbedding(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
super().__init__()
assert embed_dim % num_heads == 0
self.num_heads = num_heads
head_dim = embed_dim // num_heads
self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device)
self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
def forward(self, x, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
b, l, d = x.shape
q, k = self.c_qk(x).chunk(2, dim=-1)
v = self.c_v(x)
q = q.view(b, l, self.num_heads, -1).transpose(1, 2)
k = k.view(b, l, self.num_heads, -1).transpose(1, 2)
v = v.view(b, l, self.num_heads, -1).transpose(1, 2)
q = self.q_norm(q)
k = self.k_norm(k)
if kv_cache is not None:
if not decode:
kv_cache.key_states[:, :, :k.shape[2], :].copy_(k)
kv_cache.value_states[:, :, :k.shape[2], :].copy_(v)
else:
kv_cache.update(curr_pos_id, k, v)
k = kv_cache.key_states
v = kv_cache.value_states
y = sdpa_with_rope(q, k, v, freqs_cis=freqs_cis, attn_mask=attn_mask,
curr_pos_id=curr_pos_id if decode else None, is_causal=is_causal)
y = y.transpose(1, 2).contiguous().view(b, l, d)
return self.c_proj(y)
class DecoderLayerWithRotaryEmbedding(nn.Module):
"""Single-stream decoder layer (shape tokens only)."""
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
super().__init__()
self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
self.attn = SelfAttentionWithRotaryEmbedding(embed_dim, num_heads, bias=bias, eps=eps,
dtype=dtype, device=device, operations=operations)
self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations)
def forward(self, x, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
x = x + self.attn(self.ln_1(x), freqs_cis=freqs_cis, attn_mask=attn_mask, is_causal=is_causal,
kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=decode)
x = x + self.mlp(self.ln_2(x))
return x
# ---------------------------------------------------------------------------
# Dual-stream blocks (faithful to dual_stream_attention.py)
# ---------------------------------------------------------------------------
class DismantledPreAttention(nn.Module):
def __init__(self, embed_dim, num_heads, query=True, bias=True, dtype=None, device=None, operations=None):
super().__init__()
assert embed_dim % num_heads == 0
self.query = query
head_dim = embed_dim // num_heads
if query:
self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device)
self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
else:
self.c_k = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.num_heads = num_heads
def _to_mha(self, x):
return x.view(*x.shape[:2], self.num_heads, -1).transpose(1, 2)
def forward(self, x):
if self.query:
q, k = self.c_qk(x).chunk(2, dim=-1)
q = self.q_norm(self._to_mha(q))
else:
q = None
k = self.c_k(x)
k = self.k_norm(self._to_mha(k))
v = self._to_mha(self.c_v(x))
return (q, k, v)
class DismantledPostAttention(nn.Module):
def __init__(self, embed_dim, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
super().__init__()
self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.ln_3 = CubeLayerNorm(embed_dim, eps=eps)
self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations)
def forward(self, x, a):
x = x + self.c_proj(a)
x = x + self.mlp(self.ln_3(x))
return x
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.cond_pre_only = cond_pre_only
self.pre_x = DismantledPreAttention(embed_dim, num_heads, query=True, bias=bias,
dtype=dtype, device=device, operations=operations)
self.pre_c = DismantledPreAttention(embed_dim, num_heads, query=not cond_pre_only, bias=bias,
dtype=dtype, device=device, operations=operations)
def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
if kv_cache is None or not decode:
qkv_c = self.pre_c(c)
qkv_x = self.pre_x(x)
if self.cond_pre_only:
q = qkv_x[0]
else:
q = torch.cat([qkv_c[0], qkv_x[0]], dim=2)
k = torch.cat([qkv_c[1], qkv_x[1]], dim=2)
v = torch.cat([qkv_c[2], qkv_x[2]], dim=2)
else:
is_causal = False
q, k, v = self.pre_x(x)
if kv_cache is not None:
if not decode:
kv_cache.key_states[:, :, :k.shape[2], :].copy_(k)
kv_cache.value_states[:, :, :k.shape[2], :].copy_(v)
else:
kv_cache.update(curr_pos_id, k, v)
k = kv_cache.key_states
v = kv_cache.value_states
if attn_mask is not None:
if decode:
attn_mask = attn_mask[..., curr_pos_id, :]
else:
attn_mask = attn_mask[..., -q.shape[2]:, :]
y = sdpa_with_rope(q, k, v, freqs_cis=freqs_cis, attn_mask=attn_mask,
curr_pos_id=curr_pos_id if decode else None, is_causal=is_causal)
y = y.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])
if y.shape[1] == x.shape[1]:
return y, None
y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1)
return y_x, y_c
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, eps=1e-6,
dtype=None, device=None, operations=None):
super().__init__()
self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
self.attn = DualStreamAttentionWithRotaryEmbedding(embed_dim, num_heads, cond_pre_only=cond_pre_only,
bias=bias, dtype=dtype, device=device, operations=operations)
self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations)
if not cond_pre_only:
self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations)
def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
a_x, a_c = self.attn(
self.ln_1(x),
self.ln_2(c) if c is not None else None,
freqs_cis=freqs_cis, attn_mask=attn_mask, is_causal=is_causal,
kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=decode,
)
x = self.post_1(x, a_x)
if a_c is not None:
c = self.post_2(c, a_c)
else:
c = None
return x, c
# ---------------------------------------------------------------------------
# DualStreamRoformer
# ---------------------------------------------------------------------------
class DualStreamRoformer(nn.Module):
def __init__(
self,
n_layer=23,
n_single_layer=1,
rope_theta=10000,
n_head=12,
n_embd=1536,
bias=True,
eps=1e-6,
shape_model_vocab_size=16384,
shape_model_embed_dim=32,
text_model_embed_dim=768,
use_bbox=True,
image_model=None, # detection key; unused
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.n_layer = n_layer
self.n_single_layer = n_single_layer
self.n_head = n_head
self.n_embd = n_embd
self.rope_theta = rope_theta
self.head_dim = n_embd // n_head
self.text_proj = operations.Linear(text_model_embed_dim, n_embd, bias=bias, dtype=dtype, device=device)
self.shape_proj = operations.Linear(shape_model_embed_dim, n_embd, bias=True, dtype=dtype, device=device)
self.vocab_size = shape_model_vocab_size
self.shape_bos_id = self.vocab_size
self.shape_eos_id = self.vocab_size + 1
self.padding_id = self.vocab_size + 2
self.vocab_size += 3
self.transformer = nn.ModuleDict(dict(
wte=operations.Embedding(self.vocab_size, n_embd, padding_idx=self.padding_id, dtype=dtype, device=device),
dual_blocks=nn.ModuleList([
DualStreamDecoderLayerWithRotaryEmbedding(
n_embd, n_head, cond_pre_only=(i == n_layer - 1), bias=bias, eps=eps,
dtype=dtype, device=device, operations=operations,
)
for i in range(n_layer)
]),
single_blocks=nn.ModuleList([
DecoderLayerWithRotaryEmbedding(n_embd, n_head, bias=bias, eps=eps,
dtype=dtype, device=device, operations=operations)
for _ in range(n_single_layer)
]),
ln_f=CubeLayerNorm(n_embd, eps=eps),
))
self.lm_head = operations.Linear(n_embd, self.vocab_size, bias=False, dtype=dtype, device=device)
self.use_bbox = use_bbox
if use_bbox:
self.bbox_proj = operations.Linear(3, n_embd, bias=True, dtype=dtype, device=device)
def encode_text(self, text_embed):
return self.text_proj(text_embed)
def encode_token(self, tokens):
return self.transformer.wte(tokens)
def init_kv_cache(self, batch_size, cond_len, max_shape_tokens, dtype, device):
max_all = cond_len + max_shape_tokens
kv = [
Cache(
torch.zeros((batch_size, self.n_head, max_all, self.head_dim), dtype=dtype, device=device),
torch.zeros((batch_size, self.n_head, max_all, self.head_dim), dtype=dtype, device=device),
)
for _ in range(len(self.transformer.dual_blocks))
]
kv += [
Cache(
torch.zeros((batch_size, self.n_head, max_shape_tokens, self.head_dim), dtype=dtype, device=device),
torch.zeros((batch_size, self.n_head, max_shape_tokens, self.head_dim), dtype=dtype, device=device),
)
for _ in range(len(self.transformer.single_blocks))
]
return kv
def forward(self, embed, cond, kv_cache=None, curr_pos_id=None, decode=False):
b, l = embed.shape[:2]
s = cond.shape[1]
device = embed.device
attn_mask = torch.tril(torch.ones(s + l, s + l, dtype=torch.bool, device=device))
position_ids = torch.arange(l, dtype=torch.long, device=device).unsqueeze(0).expand(b, -1)
s_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta)
position_ids = torch.cat([
torch.zeros([b, s], dtype=torch.long, device=device),
position_ids,
], dim=1)
d_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta)
if kv_cache is not None and decode:
embed = embed[:, curr_pos_id, :]
h = embed
c = cond
layer_idx = 0
for block in self.transformer.dual_blocks:
h, c = block(
h, c=c, freqs_cis=d_freqs_cis, attn_mask=attn_mask, is_causal=True,
kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
decode=decode,
)
layer_idx += 1
for block in self.transformer.single_blocks:
h = block(
h, freqs_cis=s_freqs_cis, attn_mask=None, is_causal=True,
kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
curr_pos_id=curr_pos_id, decode=decode,
)
layer_idx += 1
h = self.transformer.ln_f(h)
return self.lm_head(h)

View File

@ -0,0 +1,379 @@
"""Dependency-free marching cubes (classic Lorensen/Cline) in pure PyTorch.
Vendored so Cube3D mesh extraction needs no scikit-image. This is the same
algorithm family as upstream cube's default NVIDIA-warp backend (warp.MarchingCubes),
so geometry is closer to the upstream default than skimage's Lewiner fallback.
Output convention matches skimage.measure.marching_cubes: vertices are returned in
array-index coordinates (axis 0, axis 1, axis 2 of the input volume), so the caller's
`vertices / grid_size * bbox_size + bbox_min` transform applies unchanged.
The standard 256-entry triangle table (Paul Bourke / Cory Bloyd) is used with the
canonical corner and edge numbering:
corners (x, y, z): edges (corner pairs):
0: (0,0,0) 1: (1,0,0) 0:0-1 1:1-2 2:2-3 3:3-0
2: (1,1,0) 3: (0,1,0) 4:4-5 5:5-6 6:6-7 7:7-4
4: (0,0,1) 5: (1,0,1) 8:0-4 9:1-5 10:2-6 11:3-7
6: (1,1,1) 7: (0,1,1)
Here x maps to volume axis 0, y to axis 1, z to axis 2.
"""
import numpy as np
import torch
# Corner offsets in (axis0, axis1, axis2) for the 8 cube corners.
_CORNERS = np.array([
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
[0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1],
], dtype=np.int64)
# The two corner indices that each of the 12 edges connects.
_EDGE_CORNERS = np.array([
[0, 1], [1, 2], [2, 3], [3, 0],
[4, 5], [5, 6], [6, 7], [7, 4],
[0, 4], [1, 5], [2, 6], [3, 7],
], dtype=np.int64)
# Standard 256 x 16 triangle table. For cube configuration `i`, lists triples of
# edge indices forming triangles, terminated by -1.
_TRI_TABLE = [
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1],
[3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1],
[3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1],
[3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1],
[9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1],
[9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1],
[2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1],
[8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1],
[9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1],
[4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1],
[3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1],
[1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1],
[4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1],
[4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1],
[9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1],
[5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1],
[2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1],
[9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1],
[0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1],
[2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1],
[10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1],
[4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1],
[5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1],
[5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1],
[9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1],
[0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1],
[1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1],
[10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1],
[8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1],
[2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1],
[7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1],
[9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1],
[2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1],
[11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1],
[9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1],
[5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1],
[11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1],
[11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1],
[1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1],
[9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1],
[5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1],
[2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1],
[0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1],
[5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1],
[6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1],
[3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1],
[6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1],
[5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1],
[1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1],
[10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1],
[6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1],
[8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1],
[7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1],
[3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1],
[5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1],
[0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1],
[9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1],
[8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1],
[5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1],
[0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1],
[6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1],
[10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1],
[10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1],
[8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1],
[1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1],
[3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1],
[0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1],
[10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1],
[3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1],
[6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1],
[9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1],
[8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1],
[3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1],
[6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1],
[0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1],
[10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1],
[10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1],
[2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1],
[7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1],
[7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1],
[2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1],
[1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1],
[11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1],
[8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1],
[0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1],
[7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1],
[10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1],
[2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1],
[6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1],
[7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1],
[2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1],
[1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1],
[10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1],
[10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1],
[0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1],
[7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1],
[6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1],
[8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1],
[9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1],
[6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1],
[4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1],
[10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1],
[8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1],
[0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1],
[1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1],
[8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1],
[10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1],
[4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1],
[10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1],
[5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1],
[11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1],
[9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1],
[6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1],
[7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1],
[3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1],
[7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1],
[9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1],
[3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1],
[6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1],
[9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1],
[1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1],
[4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1],
[7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1],
[6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1],
[3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1],
[0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1],
[6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1],
[0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1],
[11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1],
[6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1],
[5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1],
[9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1],
[1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1],
[1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1],
[10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1],
[0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1],
[5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1],
[10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1],
[11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1],
[9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1],
[7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1],
[2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1],
[8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1],
[9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1],
[9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1],
[1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1],
[9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1],
[9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1],
[5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1],
[0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1],
[10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1],
[2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1],
[0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1],
[0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1],
[9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1],
[5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1],
[3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1],
[5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1],
[8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1],
[0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1],
[9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1],
[0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1],
[1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1],
[3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1],
[4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1],
[9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1],
[11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1],
[11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1],
[2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1],
[9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1],
[3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1],
[1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1],
[4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1],
[4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1],
[0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1],
[3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1],
[3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1],
[0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1],
[9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1],
[1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
]
@torch.no_grad()
def marching_cubes(volume: torch.Tensor, level: float = 0.0):
"""Extract an isosurface from a 3D scalar field.
Args:
volume: (D, H, W) float tensor. Inside is where ``volume < level`` (matches
the classic Lorensen convention and skimage's ``method="lorensen"``).
level: isosurface threshold.
Returns:
(vertices, faces): numpy arrays. ``vertices`` are float32 (N, 3) in array-index
coordinates (axis0, axis1, axis2); ``faces`` are int64 (M, 3).
"""
assert volume.ndim == 3, "volume must be (D, H, W)"
device = volume.device
vol = volume.float()
tri_table = torch.tensor(_TRI_TABLE, dtype=torch.long, device=device) # (256, 16)
edge_corners = torch.tensor(_EDGE_CORNERS, dtype=torch.long, device=device) # (12, 2)
corners = torch.tensor(_CORNERS, dtype=torch.float32, device=device) # (8, 3)
# Corner scalar values for every cell, shape (nc0, nc1, nc2, 8).
nc0, nc1, nc2 = vol.shape[0] - 1, vol.shape[1] - 1, vol.shape[2] - 1
if nc0 <= 0 or nc1 <= 0 or nc2 <= 0:
return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.int64))
corner_vals = torch.empty((nc0, nc1, nc2, 8), dtype=torch.float32, device=device)
for k in range(8):
o0, o1, o2 = _CORNERS[k]
corner_vals[..., k] = vol[o0:o0 + nc0, o1:o1 + nc1, o2:o2 + nc2]
# Cube configuration index: bit k set when corner k is inside (val < level).
inside = (corner_vals < level)
bits = torch.tensor([1 << k for k in range(8)], dtype=torch.long, device=device)
cube_index = (inside.long() * bits).sum(dim=-1) # (nc0, nc1, nc2)
# Cells that actually intersect the surface.
active = (cube_index > 0) & (cube_index < 255)
if not active.any():
return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.int64))
idx0, idx1, idx2 = torch.where(active) # (Nactive,)
cidx = cube_index[idx0, idx1, idx2] # (Nactive,)
cell_origin = torch.stack([idx0, idx1, idx2], dim=1).float() # (Nactive, 3)
cell_vals = corner_vals[idx0, idx1, idx2] # (Nactive, 8)
tris = tri_table[cidx] # (Nactive, 16)
# Each row holds up to 5 triangles (15 edge entries). Expand to (Nactive, 5, 3).
tri_edges = tris[:, :15].reshape(-1, 5, 3) # edge indices, -1 = unused
valid_tri = tri_edges[..., 0] >= 0 # (Nactive, 5)
cell_idx = torch.arange(cell_origin.shape[0], device=device).unsqueeze(1).expand(-1, 5)
cell_idx = cell_idx[valid_tri] # (T,)
edges = tri_edges[valid_tri] # (T, 3) edge index per triangle corner
# Interpolate a vertex on each referenced edge.
e_flat = edges.reshape(-1) # (T*3,)
cell_for_vert = cell_idx.unsqueeze(1).expand(-1, 3).reshape(-1) # (T*3,)
ca = edge_corners[e_flat, 0] # (T*3,) corner index a
cb = edge_corners[e_flat, 1] # corner index b
va = cell_vals[cell_for_vert, ca] # scalar at corner a
vb = cell_vals[cell_for_vert, cb]
pa = cell_origin[cell_for_vert] + corners[ca] # position of corner a (index space)
pb = cell_origin[cell_for_vert] + corners[cb]
denom = (vb - va)
t = torch.where(denom.abs() > 1e-12, (level - va) / denom, torch.zeros_like(denom))
t = t.clamp(0.0, 1.0).unsqueeze(1)
verts = pa + t * (pb - pa) # (T*3, 3) one vertex per triangle corner
# Weld shared vertices: a grid edge shared by adjacent cells interpolates to the exact
# same position (same corner values/positions), so exact dedup yields a clean indexed
# mesh like skimage/warp (one vertex per active edge).
uniq, inverse = torch.unique(verts, dim=0, return_inverse=True)
faces = inverse.reshape(-1, 3)
return (uniq.cpu().numpy().astype(np.float32), faces.cpu().numpy().astype(np.int64))

364
comfy/ldm/cube/vae.py Normal file
View File

@ -0,0 +1,364 @@
"""
Native port of Roblox/cube's shape tokenizer decode path (OneDAutoEncoder).
Reference: https://github.com/Roblox/cube (cube3d/model/autoencoder/*).
Only the DECODE path is ported (token IDs -> latents -> occupancy grid -> mesh);
the point-cloud encoder is not needed for text-to-3D generation. Encoder weights in
the checkpoint are loaded with strict=False and ignored.
Module/parameter names mirror upstream so the checkpoint loads directly:
embedder.weight
bottleneck.block.{codebook, cb_weight, cb_bias, c_in, c_x, c_out, ...}
decoder.{positional_encodings, blocks.N...}
occupancy_decoder.{query_in, attn_out, ln_f, c_head}
"""
import logging
import math
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
ops = comfy.ops.disable_weight_init
# ---------------------------------------------------------------------------
# Norms
# ---------------------------------------------------------------------------
class CubeLayerNorm(nn.Module):
"""LayerNorm upcasting to fp32. affine=False by default (no params)."""
def __init__(self, dim, eps=1e-6, elementwise_affine=False, dtype=None, device=None):
super().__init__()
self.dim = (dim,)
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
self.bias = nn.Parameter(torch.zeros(dim, dtype=dtype, device=device))
else:
self.weight = None
self.bias = None
def forward(self, x):
w = self.weight.float() if self.weight is not None else None
b = self.bias.float() if self.bias is not None else None
y = F.layer_norm(x.float(), self.dim, w, b, self.eps)
return y.type_as(x)
class CubeRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5, elementwise_affine=True, dtype=None, device=None):
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
else:
self.register_buffer("weight", torch.ones(dim), persistent=False)
def forward(self, x):
xf = x.float()
out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
return (out * self.weight.float()).type_as(x)
# ---------------------------------------------------------------------------
# Fourier embedder
# ---------------------------------------------------------------------------
class PhaseModulatedFourierEmbedder(nn.Module):
def __init__(self, num_freqs, input_dim=3, dtype=None, device=None):
super().__init__()
self.weight = nn.Parameter(torch.empty(input_dim, num_freqs, dtype=dtype, device=device))
carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * math.pi
self.register_buffer("carrier", carrier, persistent=False)
self.out_dim = input_dim * (num_freqs * 2 + 1)
def forward(self, x):
m = x.float().unsqueeze(-1)
w = self.weight.float()
carrier = self.carrier.float()
fm = (m * w).view(*x.shape[:-1], -1)
pm = (m * 0.5 * math.pi + carrier).view(*x.shape[:-1], -1)
return torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1).type_as(x)
# ---------------------------------------------------------------------------
# Attention building blocks
# ---------------------------------------------------------------------------
class MLP(nn.Module):
def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None):
super().__init__()
self.up_proj = ops.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
self.down_proj = ops.Linear(hidden_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.act_fn = nn.GELU(approximate="none")
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
class SelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None):
super().__init__()
assert embed_dim % num_heads == 0
self.num_heads = num_heads
head_dim = embed_dim // num_heads
self.c_qk = ops.Linear(embed_dim, 2 * embed_dim, bias=bias, dtype=dtype, device=device)
self.c_v = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
def forward(self, x, attn_mask=None, is_causal=False):
b, l, d = x.shape
q, k = self.c_qk(x).chunk(2, dim=-1)
v = self.c_v(x)
q = self.q_norm(q.view(b, l, self.num_heads, -1).transpose(1, 2))
k = self.k_norm(k.view(b, l, self.num_heads, -1).transpose(1, 2))
v = v.view(b, l, self.num_heads, -1).transpose(1, 2)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0,
is_causal=is_causal and attn_mask is None)
y = y.transpose(1, 2).contiguous().view(b, l, d)
return self.c_proj(y)
class CrossAttention(nn.Module):
def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, dtype=None, device=None):
super().__init__()
assert embed_dim % num_heads == 0
q_dim = q_dim or embed_dim
kv_dim = kv_dim or embed_dim
self.c_q = ops.Linear(q_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.c_k = ops.Linear(kv_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.c_v = ops.Linear(kv_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.num_heads = num_heads
def forward(self, x, c, attn_mask=None):
q, k, v = self.c_q(x), self.c_k(c), self.c_v(c)
b, l, d = q.shape
s = k.shape[1]
q = q.view(b, l, self.num_heads, -1).transpose(1, 2)
k = k.view(b, s, self.num_heads, -1).transpose(1, 2)
v = v.view(b, s, self.num_heads, -1).transpose(1, 2)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
y = y.transpose(1, 2).contiguous().view(b, l, d)
return self.c_proj(y)
class EncoderLayer(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None):
super().__init__()
self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
self.attn = SelfAttention(embed_dim, num_heads, bias=bias, eps=eps, dtype=dtype, device=device)
self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device)
def forward(self, x, attn_mask=None, is_causal=False):
x = x + self.attn(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)
x = x + self.mlp(self.ln_2(x))
return x
class EncoderCrossAttentionLayer(nn.Module):
def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, eps=1e-6, dtype=None, device=None):
super().__init__()
q_dim = q_dim or embed_dim
kv_dim = kv_dim or embed_dim
self.attn = CrossAttention(embed_dim, num_heads, q_dim=q_dim, kv_dim=kv_dim, bias=bias, dtype=dtype, device=device)
self.ln_1 = CubeLayerNorm(q_dim, eps=eps)
self.ln_2 = CubeLayerNorm(kv_dim, eps=eps)
self.ln_f = CubeLayerNorm(embed_dim, eps=eps)
self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device)
def forward(self, x, c, attn_mask=None):
x = x + self.attn(self.ln_1(x), self.ln_2(c), attn_mask=attn_mask)
x = x + self.mlp(self.ln_f(x))
return x
class MLPEmbedder(nn.Module):
def __init__(self, in_dim, embed_dim, bias=True, dtype=None, device=None):
super().__init__()
self.in_layer = ops.Linear(in_dim, embed_dim, bias=bias, dtype=dtype, device=device)
self.silu = nn.SiLU()
self.out_layer = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
def forward(self, x):
return self.out_layer(self.silu(self.in_layer(x)))
# ---------------------------------------------------------------------------
# Spherical VQ (decode-only parts)
# ---------------------------------------------------------------------------
class SphericalVectorQuantizer(nn.Module):
def __init__(self, embed_dim, num_codes, width=None, dtype=None, device=None):
super().__init__()
self.num_codes = num_codes
self.codebook = ops.Embedding(num_codes, embed_dim, dtype=dtype, device=device)
width = width or embed_dim
if width != embed_dim:
self.c_in = ops.Linear(width, embed_dim, dtype=dtype, device=device)
self.c_x = ops.Linear(width, embed_dim, dtype=dtype, device=device)
self.c_out = ops.Linear(embed_dim, width, dtype=dtype, device=device)
else:
self.c_in = self.c_out = self.c_x = nn.Identity()
self.norm = CubeRMSNorm(embed_dim, elementwise_affine=False, dtype=dtype, device=device)
# "kl" codebook regularization (released config)
self.cb_weight = nn.Parameter(torch.ones([embed_dim], dtype=dtype, device=device))
self.cb_bias = nn.Parameter(torch.zeros([embed_dim], dtype=dtype, device=device))
def cb_norm(self, x):
return x * self.cb_weight + self.cb_bias
def get_codebook(self):
return self.norm(self.cb_norm(self.codebook.weight))
def lookup_codebook(self, q):
z_q = F.embedding(q, self.get_codebook())
return self.c_out(z_q)
class OneDBottleNeck(nn.Module):
def __init__(self, block):
super().__init__()
self.block = block
# ---------------------------------------------------------------------------
# Decoders
# ---------------------------------------------------------------------------
class OneDDecoder(nn.Module):
def __init__(self, num_latents, width, num_heads, num_layers, eps=1e-6, dtype=None, device=None):
super().__init__()
self.register_buffer("query", torch.empty([0, width]), persistent=False)
self.positional_encodings = nn.Parameter(torch.empty(num_latents, width, dtype=dtype, device=device))
self.blocks = nn.ModuleList([
EncoderLayer(width, num_heads, eps=eps, dtype=dtype, device=device)
for _ in range(num_layers)
])
def forward(self, z):
h = z + self.positional_encodings[:z.shape[1]].unsqueeze(0).to(z.dtype)
for block in self.blocks:
h = block(h)
return h
class OneDOccupancyDecoder(nn.Module):
def __init__(self, embedder, out_features, width, num_heads, eps=1e-6, dtype=None, device=None):
super().__init__()
self.embedder = embedder
self.query_in = MLPEmbedder(embedder.out_dim, width, dtype=dtype, device=device)
self.attn_out = EncoderCrossAttentionLayer(width, num_heads, dtype=dtype, device=device)
self.ln_f = CubeLayerNorm(width, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.c_head = ops.Linear(width, out_features, dtype=dtype, device=device)
def forward(self, queries, latents):
x = self.query_in(self.embedder(queries))
x = self.attn_out(x, latents)
return self.c_head(self.ln_f(x))
# ---------------------------------------------------------------------------
# Top-level shape VAE
# ---------------------------------------------------------------------------
def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij"):
length = bbox_max - bbox_min
num_cells = np.exp2(resolution_base)
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
xs, ys, zs = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1).reshape(-1, 3)
grid_size = [int(num_cells) + 1] * 3
return xyz, grid_size, length
class CubeShapeVAE(nn.Module):
"""Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored)."""
# Fixed query bounds for the occupancy grid (upstream default).
decode_bounds = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05)
def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12,
num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6,
dtype=None, device=None):
super().__init__()
self.cfg_num_encoder_latents = num_encoder_latents
self.cfg_num_codes = num_codes
self.embedder = PhaseModulatedFourierEmbedder(num_freqs=num_freqs, input_dim=3, dtype=dtype, device=device)
self.bottleneck = OneDBottleNeck(
SphericalVectorQuantizer(embed_dim, num_codes, width, dtype=dtype, device=device)
)
self.decoder = OneDDecoder(num_encoder_latents, width, num_heads, num_decoder_layers,
eps=eps, dtype=dtype, device=device)
self.occupancy_decoder = OneDOccupancyDecoder(self.embedder, out_dim, width, num_heads,
eps=eps, dtype=dtype, device=device)
@torch.no_grad()
def decode(self, samples, resolution_base=8.0, chunk_size=100_000, **kwargs):
"""Token IDs -> occupancy grid logits. Entry point for comfy.sd.VAE.decode, which
manages model loading/device/dtype. `samples` arrive as (B, 1, num_tokens) in the
VAE working dtype on the load device. VAE.decode applies a trailing movedim(1, -1),
so pre-invert it here to hand the node grid logits as (B, gx, gy, gz)."""
ids = samples.reshape(samples.shape[0], -1)[:, :self.cfg_num_encoder_latents]
ids = ids.round().long().clamp(0, self.cfg_num_codes - 1)
latents = self.decode_indices(ids)
grid_logits, _, _, _ = self.extract_geometry(
latents, bounds=self.decode_bounds, resolution_base=resolution_base, chunk_size=chunk_size)
return grid_logits.movedim(-1, 1)
@torch.no_grad()
def decode_indices(self, shape_ids):
z_q = self.bottleneck.block.lookup_codebook(shape_ids)
return self.decoder(z_q)
@torch.no_grad()
def query(self, queries, latents):
return self.occupancy_decoder(queries, latents).squeeze(-1)
@torch.no_grad()
def extract_geometry(self, latents, bounds=(-1.05, -1.05, -1.05, 1.05, 1.05, 1.05),
resolution_base=8.0, chunk_size=100_000):
bbox_min = np.array(bounds[0:3])
bbox_max = np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
xyz, grid_size, _ = generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij")
xyz = torch.from_numpy(xyz)
batch_size = latents.shape[0]
batch_logits = []
for start in range(0, xyz.shape[0], chunk_size):
queries = xyz[start:start + chunk_size, :]
n = queries.shape[0]
if start > 0 and n < chunk_size:
queries = F.pad(queries, [0, 0, 0, chunk_size - n])
bq = queries.unsqueeze(0).expand(batch_size, -1, -1).to(latents)
batch_logits.append(self.query(bq, latents)[:, :n])
grid_logits = torch.cat(batch_logits, dim=1).detach().view(
batch_size, grid_size[0], grid_size[1], grid_size[2]).float()
return grid_logits, grid_size, bbox_size, bbox_min
def grid_logits_to_mesh(grid_logit, grid_size, bbox_size, bbox_min, level=0.0):
"""Occupancy-logit grid -> mesh, using the vendored dependency-free marching cubes
(classic Lorensen, same family as upstream cube's default warp backend). Vertices are
rescaled from grid-index space into the bbox, matching upstream's transform."""
from comfy.ldm.cube.marching_cubes import marching_cubes
vertices, faces = marching_cubes(grid_logit, level)
vertices = vertices / np.array(grid_size) * bbox_size + bbox_min
# The vendored Lorensen table already emits outward-facing winding for this
# occupancy convention, so (unlike the upstream skimage path) no face flip is needed.
return vertices.astype(np.float32), np.ascontiguousarray(faces)

View File

@ -44,6 +44,7 @@ import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.cube.gpt
import comfy.ldm.wan.model_wandancer
import comfy.ldm.hunyuan3d.model
import comfy.ldm.triposplat.model
@ -1903,6 +1904,26 @@ class Hunyuan3Dv2(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class Cube3D(BaseModel):
"""Roblox Cube3D shape GPT (autoregressive). Generation goes through the
dedicated `cube` sampler (SamplerCustomAdvanced), never KSampler/apply_model."""
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cube.gpt.DualStreamRoformer)
def _apply_model(self, *args, **kwargs):
raise RuntimeError(
"Cube3D is an autoregressive token model. Use the 'cube' sampler "
"(SamplerCube + SamplerCustomAdvanced), not KSampler."
)
def extra_conds(self, **kwargs):
out = {}
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Hunyuan3Dv2_1(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)

View File

@ -654,6 +654,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if '{}shape_proj.weight'.format(key_prefix) in state_dict_keys and '{}lm_head.weight'.format(key_prefix) in state_dict_keys: # Roblox Cube3D shape GPT
dit_config = {}
dit_config["image_model"] = "cube3d"
n_embd = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[1]
dit_config["n_embd"] = n_embd
dit_config["shape_model_vocab_size"] = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[0] - 3
dit_config["n_layer"] = count_blocks(state_dict_keys, '{}transformer.dual_blocks.'.format(key_prefix) + '{}.')
dit_config["n_single_layer"] = count_blocks(state_dict_keys, '{}transformer.single_blocks.'.format(key_prefix) + '{}.')
head_dim = state_dict['{}transformer.dual_blocks.0.attn.pre_x.q_norm.weight'.format(key_prefix)].shape[0]
dit_config["n_head"] = n_embd // head_dim
dit_config["shape_model_embed_dim"] = state_dict['{}shape_proj.weight'.format(key_prefix)].shape[1]
dit_config["text_model_embed_dim"] = state_dict['{}text_proj.weight'.format(key_prefix)].shape[1]
dit_config["use_bbox"] = '{}bbox_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["bias"] = '{}text_proj.bias'.format(key_prefix) in state_dict_keys
dit_config["rope_theta"] = 10000 # not stored in the state dict; upstream's fixed constant
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
dit_config = {}

View File

@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.cube.vae
import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
@ -777,6 +778,39 @@ class VAE:
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# Roblox Cube3D shape tokenizer (OneDAutoEncoder, decode-only)
elif "bottleneck.block.codebook.weight" in sd:
self.latent_dim = 1
# The VQ bottleneck (get_codebook/lookup_codebook) reads raw parameters
# outside any hooked forward, so the streaming-offload cast hooks can't
# relocate them; the model must be fully resident to decode. This is a
# correctness requirement, declared via the standard flag (like the audio
# VAEs) rather than managed manually in the node.
self.disable_offload = True
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
width = sd["bottleneck.block.c_out.weight"].shape[0]
num_encoder_latents = sd["decoder.positional_encodings"].shape[0]
head_dim = sd["decoder.blocks.0.attn.q_norm.weight"].shape[0]
num_heads = width // head_dim
num_freqs = sd["embedder.weight"].shape[1]
num_decoder_layers = len({k.split(".")[2] for k in sd if k.startswith("decoder.blocks.")})
self.first_stage_model = comfy.ldm.cube.vae.CubeShapeVAE(
num_encoder_latents=num_encoder_latents, embed_dim=embed_dim, width=width,
num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers,
num_codes=num_codes,
)
# Decode goes through the managed comfy.sd.VAE.decode path; the grid logits
# are float32 regardless of weight dtype, so keep process_output identity
# (the default clamps to [0, 1] in-place and would destroy the isosurface).
self.process_output = lambda image: image
self.process_input = lambda image: image
# shape is the token-ID latent (B, 1, num_tokens); size by num_tokens.
self.memory_used_decode = lambda shape, dtype: (1000 * shape[-1] * 768) * model_management.dtype_size(dtype)
# fp32-only (unlike most VAEs that allow fp16/bf16): the VQ codebook lookup
# and occupancy-grid query must run in fp32 to match upstream and keep the
# isosurface stable.
self.working_dtypes = [torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)

View File

@ -1550,6 +1550,32 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
class Cube3D(supported_models_base.BASE):
unet_config = {
"image_model": "cube3d",
}
unet_extra_config = {}
sampling_settings = {}
latent_format = latent_formats.Cube3D
memory_usage_factor = 1.0
# Upstream keeps fp32 weights and uses bf16 autocast during the forward pass
# (see sample_cube). Prefer fp32 weights for parity; bf16 is the low-VRAM fallback.
supported_inference_dtypes = [torch.float32, torch.bfloat16]
def get_model(self, state_dict, prefix="", device=None):
return model_base.Cube3D(self, device=device)
def clip_target(self, state_dict={}):
# No bundled text encoder: the cube checkpoint is GPT-only. The graph wires a
# standard CLIPLoader(clip-l)/CLIPTextEncode, so there is no clip_target to build.
return None
class TripoSplat(supported_models_base.BASE):
# Image -> 3D gaussian splat flow denoiser
unet_config = {
@ -2292,6 +2318,7 @@ models = [
Hunyuan3Dv2mini,
Hunyuan3Dv2,
Hunyuan3Dv2_1,
Cube3D,
TripoSplat,
HiDream,
HiDreamO1,

156
comfy_extras/nodes_cube.py Normal file
View File

@ -0,0 +1,156 @@
"""
Nodes for native Roblox Cube3D text-to-3D support.
Graph:
CLIPLoader(clip-l) -> CLIPTextEncode -> CONDITIONING
UNETLoader(shape_gpt) -> MODEL --\
VAELoader(shape_tokenizer) -> VAE -> CubeCodebookPatch -> MODEL
CFGGuider(MODEL, pos, neg, cfg) + SamplerCube + (trivial sigmas) + EmptyCubeLatent
-> SamplerCustomAdvanced -> LATENT (token IDs)
VAEDecodeCube(VAE, LATENT) -> MESH -> SaveGLB
"""
import numpy as np
import torch
from typing_extensions import override
import comfy.ldm.cube.vae
import comfy.model_management
import comfy.samplers
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_extras.nodes_save_3d import pack_variable_mesh_batch
class EmptyCubeLatent(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyCubeLatent",
category="latent/3d",
inputs=[
IO.Int.Input("num_tokens", default=1024, min=1, max=8192,
tooltip="Shape token sequence length. Must match the tokenizer "
"(1024 for cube3d-v0.5, 512 for v0.1)."),
IO.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[IO.Latent.Output()],
)
@classmethod
def execute(cls, num_tokens, batch_size) -> IO.NodeOutput:
# Channels-first 1D latent (B, 1, num_tokens), mirroring Hunyuan3Dv2's (B, C, L)
# convention (latent_channels=1). The sampler only uses the sequence length.
latent = torch.zeros([batch_size, 1, num_tokens], device=comfy.model_management.intermediate_device())
return IO.NodeOutput({"samples": latent, "type": "cube_tokens"})
class CubeCodebookPatch(IO.ComfyNode):
"""Inject the projected VQ codebook into the GPT token-embedding table.
Upstream copies shape_proj(tokenizer.codebook) into wte.weight[:num_codes] at load
time; without it generation is garbage. Done here as a ModelPatcher object patch so
it composes with normal model loading/offload."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="CubeCodebookPatch",
display_name="Cube Codebook Patch",
category="advanced/model",
inputs=[
IO.Model.Input("model"),
IO.Vae.Input("vae"),
],
outputs=[IO.Model.Output()],
)
@classmethod
def execute(cls, model, vae) -> IO.NodeOutput:
gpt = model.get_model_object("diffusion_model")
codebook = vae.first_stage_model.bottleneck.block.get_codebook() # (num_codes, embed_dim) fp32
w = gpt.shape_proj.weight
proj = gpt.shape_proj(codebook.to(device=w.device, dtype=w.dtype)) # (num_codes, n_embd)
old = model.get_model_object("diffusion_model.transformer.wte.weight")
new = old.clone()
new[:proj.shape[0]] = proj.to(device=new.device, dtype=new.dtype)
m = model.clone()
m.add_object_patch("diffusion_model.transformer.wte.weight",
torch.nn.Parameter(new, requires_grad=False))
return IO.NodeOutput(m)
class SamplerCube(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SamplerCube",
display_name="Sampler Cube (autoregressive)",
category="sampling/custom_sampling/samplers",
inputs=[
IO.Float.Input("top_p", default=1.0, min=0.0, max=1.0, step=0.01,
tooltip="1.0 = deterministic greedy (upstream default). "
"<1.0 enables nucleus sampling."),
],
outputs=[IO.Sampler.Output()],
)
@classmethod
def execute(cls, top_p) -> IO.NodeOutput:
return IO.NodeOutput(comfy.samplers.ksampler("cube", {"top_p": top_p}))
class VAEDecodeCube(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeCube",
display_name="VAE Decode Cube (3D)",
category="latent/3d",
inputs=[
IO.Vae.Input("vae"),
IO.Latent.Input("samples"),
IO.Float.Input("resolution_base", default=8.0, min=4.0, max=10.0, step=0.5,
tooltip="Grid cells per axis = 2^resolution_base. 8.0 matches "
"upstream default (257^3 grid)."),
IO.Int.Input("chunk_size", default=100000, min=1000, max=2000000, advanced=True),
],
outputs=[IO.Mesh.Output()],
)
@classmethod
def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
# Managed decode: comfy.sd.VAE.decode handles model loading + device/dtype and
# returns the occupancy grid logits (B, gx, gy, gz). Marching cubes runs here.
grid = vae.decode(samples["samples"],
vae_options={"resolution_base": resolution_base, "chunk_size": chunk_size})
bounds = vae.first_stage_model.decode_bounds
bbox_min = np.array(bounds[0:3])
bbox_size = np.array(bounds[3:6]) - bbox_min
grid_size = list(grid.shape[1:])
verts_list, faces_list = [], []
for i in range(grid.shape[0]):
v, f = comfy.ldm.cube.vae.grid_logits_to_mesh(grid[i], grid_size, bbox_size, bbox_min)
verts_list.append(torch.from_numpy(v))
faces_list.append(torch.from_numpy(f.astype(np.int64)))
mesh = pack_variable_mesh_batch(verts_list, faces_list)
return IO.NodeOutput(mesh)
class CubeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
EmptyCubeLatent,
CubeCodebookPatch,
SamplerCube,
VAEDecodeCube,
]
async def comfy_entrypoint() -> CubeExtension:
return CubeExtension()

View File

@ -317,74 +317,11 @@ class PreviewPointCloud(IO.ComfyNode):
)
MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
class Load3DAdvanced(IO.ComfyNode):
@classmethod
def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in MESH_EXTENSIONS
]
return IO.Schema(
node_id="Load3DAdvanced",
display_name="Load 3D (Advanced)",
category="3d",
search_aliases=[
"load mesh",
"load gltf",
"load glb",
"load obj",
"load fbx",
"load stl",
],
is_experimental=True,
inputs=[
IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model),
IO.Load3D.Input("viewport_state"),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
IO.File3DAny.Output(display_name="model_3d"),
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
],
)
@classmethod
def validate_inputs(cls, model_file, **kwargs) -> bool | str:
if not model_file or model_file == "none":
return True
if not folder_paths.exists_annotated_filepath(model_file):
return f"Invalid 3D model file: {model_file}"
return True
@classmethod
def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
file_3d = None
if model_file and model_file != "none":
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
model_3d_info = viewport_state.get('model_3d_info', [])
return IO.NodeOutput(file_3d, model_3d_info, viewport_state['camera_info'], width, height)
class Load3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Load3D,
Load3DAdvanced,
Preview3D,
Preview3DAdvanced,
PreviewGaussianSplat,

View File

@ -2433,6 +2433,7 @@ async def init_builtin_extra_nodes():
"nodes_kandinsky5.py",
"nodes_wanmove.py",
"nodes_ar_video.py",
"nodes_cube.py",
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_glsl.py",

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.45.15
comfyui-workflow-templates==0.9.98
comfyui-embedded-docs==0.5.4
comfyui-embedded-docs==0.5.3
torch
torchsde
torchvision