mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-16 03:46:06 +08:00
Compare commits
20 Commits
ci-failure
...
feat/cube3
| Author | SHA1 | Date | |
|---|---|---|---|
| e7f99168ae | |||
| 029b782936 | |||
| 81f5f84ad6 | |||
| d8635dcb39 | |||
| aeb3c77ae9 | |||
| a6c7397b71 | |||
| 871f7bc390 | |||
| 01a8783bee | |||
| 4388eb781a | |||
| e1b9366898 | |||
| 5897d0c3ae | |||
| a1d95f3f82 | |||
| 64cc078069 | |||
| 740d347279 | |||
| b664349ae7 | |||
| fe54b5e955 | |||
| 7277d99d3a | |||
| 28a40fb2b2 | |||
| d7a552720c | |||
| 02656ea0bb |
13
.github/workflows/test-ci.yml
vendored
13
.github/workflows/test-ci.yml
vendored
@ -97,16 +97,3 @@ jobs:
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
notify-failure:
|
||||
needs: [test-stable, test-unix-nightly]
|
||||
if: ${{ failure() && github.event_name == 'push' }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Notify Slack of CI failure
|
||||
uses: slackapi/slack-github-action@v2.1.0
|
||||
with:
|
||||
webhook: ${{ secrets.CI_ALERTS_SLACK_WEBHOOK }}
|
||||
webhook-type: incoming-webhook
|
||||
payload: |
|
||||
text: ":rotating_siren: ComfyUI CI failed on `${{ github.ref_name }}`\n*Commit:* <${{ github.server_url }}/${{ github.repository }}/commit/${{ github.sha }}|${{ github.sha }}>\n*Run:* <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|view logs>\n*Dashboard:* <https://ci.comfy.org/?branch=${{ github.ref_name }}|ci.comfy.org>"
|
||||
|
||||
@ -382,11 +382,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory-efficient attention on recent PyTorch versions in ComfyUI on some AMD GPUs using the command below; it should already be enabled by default on RDNA3. If this improves speed for you with the latest PyTorch on your GPU, please report it so that I can enable it by default.
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
You can try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
|
||||
# Notes
|
||||
|
||||
|
||||
@ -115,6 +115,7 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metav
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -249,6 +250,9 @@ else:
|
||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||
|
||||
if args.high_ram:
|
||||
args.cache_classic = True
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
|
||||
@ -1955,3 +1955,120 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
transformer_options.pop("ar_state", None)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _cube_process_logits(logits, top_p, generator):
    """Select one token index per row of `logits`.

    When `top_p` is None, >= 1 or <= 0 the choice is a plain argmax over the last
    dimension (deterministic). Otherwise the logits are filtered by cumulative
    probability and one index is drawn with `torch.multinomial`.

    Returns a tensor with a trailing dimension of size 1 holding the chosen index.
    """
    use_greedy = (top_p is None) or (top_p <= 0.0) or (top_p >= 1.0)
    if use_greedy:
        return logits.argmax(dim=-1, keepdim=True)

    ranked_logits, ranked_idx = torch.sort(logits, dim=-1, descending=True)
    cumulative = torch.cumsum(torch.softmax(ranked_logits, dim=-1), dim=-1)

    # Entries whose cumulative probability exceeds top_p are dropped; the single
    # most likely entry is always kept so the distribution never becomes empty.
    drop_ranked = cumulative > top_p
    drop_ranked[..., 0] = False

    # Map the mask from sorted order back to the original vocabulary order.
    drop = drop_ranked.scatter(-1, ranked_idx, drop_ranked)
    filtered = logits.masked_fill(drop, float("-inf"))
    return torch.multinomial(filtered.softmax(dim=-1), num_samples=1, generator=generator)
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_cube(model, x, sigmas, extra_args=None, callback=None, disable=None, top_p=1.0):
    """
    Autoregressive sampler for Roblox Cube3D shape GPT (DualStreamRoformer).

    Not a diffusion sampler: the noised input `x` and `sigmas` values are ignored;
    only x's shape (batch, 1, num_tokens) is used. Generates a 1024-long sequence of VQ
    token IDs from CLIP text conditioning, with upstream's linearly-decaying CFG and
    optional top-p. Plugs into SamplerCustomAdvanced via the SamplerCube node.

    Faithful to cube3d.inference.engine.Engine.run_gpt:
        gamma_i = cfg * (T - i) / T ; logits = (1+gamma)*cond - gamma*uncond
        fp32 weights + bf16 autocast on cuda.

    Args:
        model: sampler wrapper; `model.inner_model` is read as the guider and
            `guider.inner_model.diffusion_model` as the Cube3D network.
        x: tensor whose first and last dimensions give batch size and token count.
        sigmas: only `sigmas[0]` is read, and only to fill the callback payload.
        extra_args: optional dict; the "seed" entry seeds the top-p generator.
            A missing or None seed is treated as 0.
        callback: optional per-step callable receiving a dict with keys
            "x", "i", "sigma", "sigma_hat", "denoised".
        disable: forwarded to the progress bar.
        top_p: nucleus threshold passed to `_cube_process_logits`.

    Returns:
        float32 tensor of token IDs shaped (batch, 1, num_tokens).

    Raises:
        ValueError: if the guider has no positive conditioning.
    """
    import comfy.model_management
    import comfy.utils

    extra_args = {} if extra_args is None else extra_args

    guider = model.inner_model  # CFGGuider
    base_model = guider.inner_model  # BaseModel (Cube3D)
    cube = base_model.diffusion_model
    cfg = getattr(guider, "cfg", 3.0)

    def get_cond(name):
        # Only the first conditioning entry of each kind is used.
        conds = guider.conds.get(name, None)
        if not conds:
            return None
        return conds[0]["model_conds"]["c_crossattn"].cond

    pos = get_cond("positive")
    neg = get_cond("negative")
    if pos is None:
        raise ValueError("sample_cube requires positive conditioning (CLIP-L text embeds).")

    device = x.device
    weight_dtype = base_model.get_dtype()
    T = x.shape[-1]  # sequence length; latent is (batch, 1, num_tokens)
    batch = x.shape[0]
    pos = comfy.utils.repeat_to_batch_size(pos, batch)
    if neg is not None:
        neg = comfy.utils.repeat_to_batch_size(neg, batch)
    use_cfg = (cfg is not None) and (cfg > 0.0) and (neg is not None)
    autocast_enabled = (device.type == "cuda")
    cache_dtype = torch.bfloat16 if autocast_enabled else weight_dtype

    def add_bbox(c):
        # Appends one extra conditioning token projected from an all-zero bbox
        # when the model was built with use_bbox.
        if not getattr(cube, "use_bbox", False):
            return c
        bbox = torch.zeros((c.shape[0], 3), device=device, dtype=c.dtype)
        return torch.cat([c, cube.bbox_proj(bbox).unsqueeze(1)], dim=1)

    # Conditioning (text_proj + bbox_proj) is computed in the model's weight dtype
    # OUTSIDE the bf16 autocast block, matching upstream cube's Engine.prepare_inputs
    # (run_clip/encode_text run in full precision). The autocast only covers the
    # autoregressive transformer forward, exactly like Engine.run_gpt.
    cond = add_bbox(cube.encode_text(pos.to(device=device, dtype=weight_dtype)))
    if use_cfg:
        ucond = add_bbox(cube.encode_text(neg.to(device=device, dtype=weight_dtype)))
        # Conditional rows first, unconditional rows second; split again by chunk(2) below.
        cond = torch.cat([cond, ucond], dim=0)

    with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
        bos = torch.full((cond.shape[0], 1), cube.shape_bos_id, dtype=torch.long, device=device)
        embed = cube.encode_token(bos)
        Bp, input_seq_len, dim = embed.shape
        # Pre-allocated buffer holding the BOS embedding followed by T generated slots.
        embed_buffer = torch.zeros((Bp, input_seq_len + T, dim), dtype=embed.dtype, device=device)
        embed_buffer[:, :input_seq_len, :].copy_(embed)

        kv_cache = cube.init_kv_cache(Bp, cond.shape[1], T + 1, cache_dtype, device)

        # The last three vocabulary entries are special IDs and are never sampled.
        num_codes = cube.vocab_size - 3

        # The seed may be absent or explicitly None; int(None) would raise, so use 0.
        seed = extra_args.get("seed", 0)
        seed = 0 if seed is None else int(seed)
        generator = None
        if device.type != "mps":
            generator = torch.Generator(device=device).manual_seed(seed)

        output_ids = []
        for i in trange(T, disable=disable):
            comfy.model_management.throw_exception_if_processing_interrupted()
            curr_pos_id = torch.tensor([i], dtype=torch.long, device=device)
            # Step 0 runs the full buffer (prefill); later steps decode one position.
            logits = cube(embed_buffer, cond, kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=(i > 0))
            logits = logits[:, 0, :num_codes]

            if use_cfg:
                cond_logits, uncond_logits = logits.float().chunk(2, dim=0)
                # Guidance strength decays linearly from cfg to cfg / T over the sequence.
                gamma = cfg * (T - i) / T
                logits = (1.0 + gamma) * cond_logits - gamma * uncond_logits
            else:
                logits = logits.float()

            next_id = _cube_process_logits(logits, top_p, generator)
            output_ids.append(next_id)

            next_embed = cube.encode_token(next_id)
            if use_cfg:
                # Both CFG halves continue from the same chosen token.
                next_embed = torch.cat([next_embed, next_embed], dim=0)
            embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))

            if callback is not None:
                callback({"x": x, "i": i, "sigma": sigmas[0], "sigma_hat": sigmas[0], "denoised": x})

    # (B, T) token IDs -> (B, 1, T) to keep the channels-first 1D latent layout.
    return torch.cat(output_ids, dim=1).to(torch.float32).unsqueeze(1)
|
||||
|
||||
@ -775,6 +775,16 @@ class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_dimensions = 1
|
||||
scale_factor = 1.0188137142395404
|
||||
|
||||
class Cube3D(LatentFormat):
    # Latent format for Roblox Cube3D shapes.
    #
    # The "latent" is a flat sequence of VQ token IDs, one scalar per position, so it
    # is stored channels-first as a 1D latent of shape (B, 1, num_tokens), following
    # the (B, C, L) convention used by Hunyuan3Dv2.

    # Token IDs must pass through process_latent_in/out unchanged.
    scale_factor = 1.0
    latent_dimensions = 1
    # A single channel keeps fix_empty_latent_channels from truncating the sequence.
    latent_channels = 1
|
||||
|
||||
class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
417
comfy/ldm/cube/gpt.py
Normal file
417
comfy/ldm/cube/gpt.py
Normal file
@ -0,0 +1,417 @@
|
||||
"""
|
||||
Native port of Roblox/cube's shape GPT (DualStreamRoformer).
|
||||
|
||||
Reference: https://github.com/Roblox/cube (cube3d/model/gpt/dual_stream_roformer.py
|
||||
and cube3d/model/transformers/*).
|
||||
|
||||
This is an autoregressive transformer over discrete VQ shape tokens, conditioned on
|
||||
CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated
|
||||
`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler.
|
||||
|
||||
The forward pass is kept faithful to upstream so token IDs match bit-for-bit:
|
||||
* rope_theta = 10000
|
||||
* per-head RMSNorm on Q and K
|
||||
* dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only
|
||||
* two separate RoPE frequency tensors (dual blocks offset cond tokens by S)
|
||||
* SwiGLU MLP, non-affine LayerNorm upcast to fp32
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Norms (faithful to cube3d/model/transformers/norm.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CubeLayerNorm(nn.Module):
    """LayerNorm with no learnable scale or shift, evaluated in float32.

    The input is upcast to float32 for the normalization and the result is cast
    back to the type of the input (matches cube).
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Stored as a one-element tuple, the form passed to F.layer_norm.
        self.dim = (dim,)

    def forward(self, x):
        normalized = F.layer_norm(
            x.float(),
            normalized_shape=self.dim,
            weight=None,
            bias=None,
            eps=self.eps,
        )
        return normalized.type_as(x)
|
||||
|
||||
|
||||
class CubeRMSNorm(nn.Module):
    """RMSNorm over the last dimension with a learnable weight, computed in float32.

    The weight is initialised to ones; the output is cast back to the type of the
    input (matches cube).
    """

    def __init__(self, dim, eps=1e-5, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        initial_scale = torch.ones(dim, dtype=dtype, device=device)
        self.weight = nn.Parameter(initial_scale)

    def forward(self, x):
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        scaled = (x32 * inv_rms) * self.weight
        return scaled.type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE (faithful to cube3d/model/transformers/rope.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
    """Rotate consecutive pairs of the last dimension of `x` by `freqs_cis`.

    `x` is viewed as complex numbers built from adjacent (even, odd) pairs and
    multiplied by the complex rotations in `freqs_cis`.

    With `curr_pos_id` None, the trailing `x.shape[2]` positions of `freqs_cis`
    are used; otherwise the positions named by `curr_pos_id` are used. The result
    has the shape and type of `x`.
    """
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(paired)

    if curr_pos_id is not None:
        rotations = freqs_cis[:, curr_pos_id, :]
    else:
        rotations = freqs_cis[:, -x.shape[2]:]
    # Insert a size-1 axis at dim 1 so the rotations broadcast over that axis of x.
    rotations = rotations.unsqueeze(1)

    rotated = torch.view_as_real(as_complex * rotations).flatten(3)
    return rotated.type_as(x)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim, t, theta=10000.0):
    """Build unit-magnitude complex rotations for the position tensor `t`.

    Returns a complex tensor shaped `(*t.shape, dim // 2)` (for even `dim`) whose
    angle at position p and channel j is `p / theta ** (2j / dim)`.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim
    inv_freq = 1.0 / (theta ** exponents)
    flat_positions = t.contiguous().view(-1)
    angles = torch.outer(flat_positions, inv_freq).reshape(*t.shape, -1)
    magnitude = torch.ones_like(angles)
    return torch.polar(magnitude, angles)
|
||||
|
||||
|
||||
def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):
    """Apply rotary embeddings to `q` and `k`, then run scaled dot-product attention.

    `q` is rotated at `curr_pos_id` when one is given; `k` is always rotated with
    the trailing positions of `freqs_cis`. The built-in causal mode is requested
    only when `is_causal` is set and no explicit `attn_mask` is supplied.
    """
    rotated_q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id)
    rotated_k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None)
    causal_flag = is_causal and attn_mask is None
    return F.scaled_dot_product_attention(
        rotated_q,
        rotated_k,
        v,
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=causal_flag,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KV cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Cache:
    """Holds the key and value buffers of one attention layer."""

    def __init__(self, key_states, value_states):
        self.key_states, self.value_states = key_states, value_states

    def update(self, curr_pos_id, k, v):
        """Write `k` and `v` in place into the buffers at index `curr_pos_id` along dim 2."""
        for buffer, fresh in ((self.key_states, k), (self.value_states, v)):
            buffer.index_copy_(2, curr_pos_id, fresh)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared building blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SwiGLUMLP(nn.Module):
    """Gated MLP: `down_proj(silu(gate_proj(x)) * up_proj(x))`.

    `operations` supplies the `Linear` class used for the three projections.
    """

    def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        linear_kwargs = dict(bias=bias, dtype=dtype, device=device)
        self.gate_proj = operations.Linear(embed_dim, hidden_dim, **linear_kwargs)
        self.up_proj = operations.Linear(embed_dim, hidden_dim, **linear_kwargs)
        self.down_proj = operations.Linear(hidden_dim, embed_dim, **linear_kwargs)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
|
||||
|
||||
|
||||
class SelfAttentionWithRotaryEmbedding(nn.Module):
    """Multi-head self-attention with per-head RMSNorm on Q/K and rotary embeddings.

    Q and K come from one fused bias-free projection (`c_qk`); V and the output
    projection honour `bias`. `eps` is accepted by the constructor but not used.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        per_head = embed_dim // num_heads
        factory = dict(dtype=dtype, device=device)
        self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, **factory)
        self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.q_norm = CubeRMSNorm(per_head, **factory)
        self.k_norm = CubeRMSNorm(per_head, **factory)

    def forward(self, x, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
        batch, seq_len, width = x.shape

        def split_heads(t):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, self.num_heads, -1).transpose(1, 2)

        q_flat, k_flat = self.c_qk(x).chunk(2, dim=-1)
        v_flat = self.c_v(x)
        q = self.q_norm(split_heads(q_flat))
        k = self.k_norm(split_heads(k_flat))
        v = split_heads(v_flat)

        if kv_cache is not None:
            if decode:
                kv_cache.update(curr_pos_id, k, v)
            else:
                # Prefill: write the fresh keys/values into the front of the cache.
                filled = k.shape[2]
                kv_cache.key_states[:, :, :filled, :].copy_(k)
                kv_cache.value_states[:, :, :filled, :].copy_(v)
            # Attention always runs against the full cache buffers.
            k, v = kv_cache.key_states, kv_cache.value_states

        # NOTE(review): in cached decode q covers one position while k/v span the
        # whole cache. If the caller passes attn_mask=None with is_causal=True,
        # torch SDPA uses an upper-left-aligned causal mask for non-square shapes
        # -- confirm this matches the intended (upstream) behaviour.
        attended = sdpa_with_rope(
            q, k, v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=curr_pos_id if decode else None,
            is_causal=is_causal,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
|
||||
|
||||
|
||||
class DecoderLayerWithRotaryEmbedding(nn.Module):
    """Single-stream decoder layer (shape tokens only).

    Pre-norm residual layout: attention on `ln_1(x)`, then a SwiGLU MLP with a
    4x hidden width on `ln_2(x)`.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        shared = dict(dtype=dtype, device=device, operations=operations)
        self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
        self.attn = SelfAttentionWithRotaryEmbedding(embed_dim, num_heads, bias=bias, eps=eps, **shared)
        self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, **shared)

    def forward(self, x, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
        attn_out = self.attn(
            self.ln_1(x),
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = x + attn_out
        return x + self.mlp(self.ln_2(x))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dual-stream blocks (faithful to dual_stream_attention.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DismantledPreAttention(nn.Module):
    """Projection half of an attention block: produces per-head (q, k, v).

    With `query=True` a fused bias-free `c_qk` yields both Q and K and Q is
    RMS-normalised; with `query=False` only K is projected (`c_k`) and the
    returned q is None. K is RMS-normalised in both modes.
    """

    def __init__(self, embed_dim, num_heads, query=True, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.query = query
        per_head = embed_dim // num_heads
        factory = dict(dtype=dtype, device=device)
        if query:
            self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, **factory)
            self.q_norm = CubeRMSNorm(per_head, **factory)
        else:
            self.c_k = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.k_norm = CubeRMSNorm(per_head, **factory)
        self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.num_heads = num_heads

    def _to_mha(self, x):
        # (batch, seq, width) -> (batch, heads, seq, head_dim)
        batch, seq_len = x.shape[:2]
        return x.view(batch, seq_len, self.num_heads, -1).transpose(1, 2)

    def forward(self, x):
        if not self.query:
            q = None
            k = self.c_k(x)
        else:
            q, k = self.c_qk(x).chunk(2, dim=-1)
            q = self.q_norm(self._to_mha(q))
        k = self.k_norm(self._to_mha(k))
        v = self._to_mha(self.c_v(x))
        return (q, k, v)
|
||||
|
||||
|
||||
class DismantledPostAttention(nn.Module):
    """Output half of an attention block.

    Adds the projected attention result `a` to the residual `x`, then applies a
    LayerNorm + SwiGLU MLP (4x hidden width) residual step.
    """

    def __init__(self, embed_dim, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
        self.ln_3 = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = SwiGLUMLP(
            embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations,
        )

    def forward(self, x, a):
        after_attn = x + self.c_proj(a)
        return after_attn + self.mlp(self.ln_3(after_attn))
|
||||
|
||||
|
||||
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
    """Joint attention over a conditioning stream `c` and a token stream `x`.

    Keys/values of both streams are concatenated along the sequence axis (cond
    first). When `cond_pre_only` is set the conditioning stream contributes no
    queries, so only the `x` stream produces an output.
    """

    def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.cond_pre_only = cond_pre_only
        shared = dict(bias=bias, dtype=dtype, device=device, operations=operations)
        self.pre_x = DismantledPreAttention(embed_dim, num_heads, query=True, **shared)
        self.pre_c = DismantledPreAttention(embed_dim, num_heads, query=not cond_pre_only, **shared)

    def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
        cached_decode = kv_cache is not None and decode
        if cached_decode:
            # Only the new token is projected; the rest comes from the cache.
            is_causal = False
            q, k, v = self.pre_x(x)
        else:
            q_c, k_c, v_c = self.pre_c(c)
            q_x, k_x, v_x = self.pre_x(x)
            q = q_x if self.cond_pre_only else torch.cat([q_c, q_x], dim=2)
            k = torch.cat([k_c, k_x], dim=2)
            v = torch.cat([v_c, v_x], dim=2)

        if kv_cache is not None:
            if decode:
                kv_cache.update(curr_pos_id, k, v)
            else:
                # Prefill: write the fresh keys/values into the front of the cache.
                filled = k.shape[2]
                kv_cache.key_states[:, :, :filled, :].copy_(k)
                kv_cache.value_states[:, :, :filled, :].copy_(v)
            k, v = kv_cache.key_states, kv_cache.value_states

        if attn_mask is not None:
            # Keep only the mask rows that correspond to the queries in use.
            if decode:
                attn_mask = attn_mask[..., curr_pos_id, :]
            else:
                attn_mask = attn_mask[..., -q.shape[2]:, :]

        attended = sdpa_with_rope(
            q, k, v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=curr_pos_id if decode else None,
            is_causal=is_causal,
        )
        attended = attended.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])

        if attended.shape[1] != x.shape[1]:
            # Output covers both streams: split back into (cond, x) parts.
            y_c, y_x = torch.split(attended, [c.shape[1], x.shape[1]], dim=1)
            return y_x, y_c
        return attended, None
|
||||
|
||||
|
||||
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
    """Decoder layer that updates a token stream `x` and a conditioning stream `c`.

    Both streams are layer-normalised, attended jointly, and then each passes
    through its own post-attention block. With `cond_pre_only` there is no
    post block for `c` (`post_2` is not created).
    """

    def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, eps=1e-6,
                 dtype=None, device=None, operations=None):
        super().__init__()
        shared = dict(dtype=dtype, device=device, operations=operations)
        self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
        self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
        self.attn = DualStreamAttentionWithRotaryEmbedding(
            embed_dim, num_heads, cond_pre_only=cond_pre_only, bias=bias, **shared,
        )
        self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, **shared)
        if not cond_pre_only:
            self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, **shared)

    def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
        normed_c = None if c is None else self.ln_2(c)
        a_x, a_c = self.attn(
            self.ln_1(x),
            normed_c,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = self.post_1(x, a_x)
        # The conditioning stream is dropped when attention returned nothing for it.
        c = None if a_c is None else self.post_2(c, a_c)
        return x, c
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DualStreamRoformer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DualStreamRoformer(nn.Module):
    """Autoregressive transformer over shape tokens with a conditioning stream.

    The network is `n_layer` dual-stream blocks (the last one is cond_pre_only)
    followed by `n_single_layer` single-stream blocks, a final LayerNorm and an
    `lm_head` over the vocabulary. Three special IDs (BOS, EOS, padding) are
    appended after the `shape_model_vocab_size` regular codes.
    """

    def __init__(
        self,
        n_layer=23,
        n_single_layer=1,
        rope_theta=10000,
        n_head=12,
        n_embd=1536,
        bias=True,
        eps=1e-6,
        shape_model_vocab_size=16384,
        shape_model_embed_dim=32,
        text_model_embed_dim=768,
        use_bbox=True,
        image_model=None,  # detection key; unused
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.n_layer = n_layer
        self.n_single_layer = n_single_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.rope_theta = rope_theta
        self.head_dim = n_embd // n_head

        factory = dict(dtype=dtype, device=device)
        block_kwargs = dict(bias=bias, eps=eps, dtype=dtype, device=device, operations=operations)

        self.text_proj = operations.Linear(text_model_embed_dim, n_embd, bias=bias, **factory)
        self.shape_proj = operations.Linear(shape_model_embed_dim, n_embd, bias=True, **factory)

        # Special token IDs sit directly after the regular codes.
        self.shape_bos_id = shape_model_vocab_size
        self.shape_eos_id = shape_model_vocab_size + 1
        self.padding_id = shape_model_vocab_size + 2
        self.vocab_size = shape_model_vocab_size + 3

        token_embedding = operations.Embedding(
            self.vocab_size, n_embd, padding_idx=self.padding_id, **factory,
        )
        dual_blocks = nn.ModuleList([
            DualStreamDecoderLayerWithRotaryEmbedding(
                n_embd, n_head, cond_pre_only=(i == n_layer - 1), **block_kwargs,
            )
            for i in range(n_layer)
        ])
        single_blocks = nn.ModuleList([
            DecoderLayerWithRotaryEmbedding(n_embd, n_head, **block_kwargs)
            for _ in range(n_single_layer)
        ])
        self.transformer = nn.ModuleDict({
            "wte": token_embedding,
            "dual_blocks": dual_blocks,
            "single_blocks": single_blocks,
            "ln_f": CubeLayerNorm(n_embd, eps=eps),
        })

        self.lm_head = operations.Linear(n_embd, self.vocab_size, bias=False, **factory)

        self.use_bbox = use_bbox
        if use_bbox:
            self.bbox_proj = operations.Linear(3, n_embd, bias=True, **factory)

    def encode_text(self, text_embed):
        """Project text embeddings to the model width with `text_proj`."""
        return self.text_proj(text_embed)

    def encode_token(self, tokens):
        """Look up token embeddings in `transformer.wte`."""
        return self.transformer.wte(tokens)

    def init_kv_cache(self, batch_size, cond_len, max_shape_tokens, dtype, device):
        """Allocate one zero-filled `Cache` per block, dual blocks first.

        Dual-block caches hold `cond_len + max_shape_tokens` positions;
        single-block caches hold `max_shape_tokens` positions.
        """
        def empty_cache(length):
            shape = (batch_size, self.n_head, length, self.head_dim)
            return Cache(
                torch.zeros(shape, dtype=dtype, device=device),
                torch.zeros(shape, dtype=dtype, device=device),
            )

        joint_len = cond_len + max_shape_tokens
        caches = [empty_cache(joint_len) for _ in range(len(self.transformer.dual_blocks))]
        caches += [empty_cache(max_shape_tokens) for _ in range(len(self.transformer.single_blocks))]
        return caches

    def forward(self, embed, cond, kv_cache=None, curr_pos_id=None, decode=False):
        """Run all blocks and return `lm_head` logits.

        `embed` is the token-embedding buffer and `cond` the conditioning tokens.
        When `kv_cache` is given and `decode` is set, only the row of `embed` at
        `curr_pos_id` is processed.
        """
        batch, seq_len = embed.shape[:2]
        cond_len = cond.shape[1]
        device = embed.device
        joint_len = cond_len + seq_len

        # Lower-triangular boolean mask over the joint (cond + token) sequence.
        attn_mask = torch.tril(torch.ones(joint_len, joint_len, dtype=torch.bool, device=device))

        # Single-stream blocks use positions 0..seq_len-1.
        token_positions = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch, -1)
        s_freqs_cis = precompute_freqs_cis(self.head_dim, token_positions, theta=self.rope_theta)

        # Dual-stream blocks see the cond tokens first, all at position 0.
        cond_positions = torch.zeros([batch, cond_len], dtype=torch.long, device=device)
        joint_positions = torch.cat([cond_positions, token_positions], dim=1)
        d_freqs_cis = precompute_freqs_cis(self.head_dim, joint_positions, theta=self.rope_theta)

        if kv_cache is not None and decode:
            embed = embed[:, curr_pos_id, :]

        def layer_cache(index):
            return None if kv_cache is None else kv_cache[index]

        # In the joint sequence, token positions are shifted past the cond tokens.
        joint_pos_id = None if curr_pos_id is None else curr_pos_id + cond_len

        hidden, context = embed, cond
        dual_blocks = self.transformer.dual_blocks
        for index, block in enumerate(dual_blocks):
            hidden, context = block(
                hidden,
                c=context,
                freqs_cis=d_freqs_cis,
                attn_mask=attn_mask,
                is_causal=True,
                kv_cache=layer_cache(index),
                curr_pos_id=joint_pos_id,
                decode=decode,
            )
        for index, block in enumerate(self.transformer.single_blocks, start=len(dual_blocks)):
            hidden = block(
                hidden,
                freqs_cis=s_freqs_cis,
                attn_mask=None,
                is_causal=True,
                kv_cache=layer_cache(index),
                curr_pos_id=curr_pos_id,
                decode=decode,
            )

        return self.lm_head(self.transformer.ln_f(hidden))
|
||||
379
comfy/ldm/cube/marching_cubes.py
Normal file
379
comfy/ldm/cube/marching_cubes.py
Normal file
@ -0,0 +1,379 @@
|
||||
"""Dependency-free marching cubes (classic Lorensen/Cline) in pure PyTorch.
|
||||
|
||||
Vendored so Cube3D mesh extraction needs no scikit-image. This is the same
|
||||
algorithm family as upstream cube's default NVIDIA-warp backend (warp.MarchingCubes),
|
||||
so geometry is closer to the upstream default than skimage's Lewiner fallback.
|
||||
|
||||
Output convention matches skimage.measure.marching_cubes: vertices are returned in
|
||||
array-index coordinates (axis 0, axis 1, axis 2 of the input volume), so the caller's
|
||||
`vertices / grid_size * bbox_size + bbox_min` transform applies unchanged.
|
||||
|
||||
The standard 256-entry triangle table (Paul Bourke / Cory Bloyd) is used with the
|
||||
canonical corner and edge numbering:
|
||||
|
||||
corners (x, y, z): edges (corner pairs):
|
||||
0: (0,0,0) 1: (1,0,0) 0:0-1 1:1-2 2:2-3 3:3-0
|
||||
2: (1,1,0) 3: (0,1,0) 4:4-5 5:5-6 6:6-7 7:7-4
|
||||
4: (0,0,1) 5: (1,0,1) 8:0-4 9:1-5 10:2-6 11:3-7
|
||||
6: (1,1,1) 7: (0,1,1)
|
||||
|
||||
Here x maps to volume axis 0, y to axis 1, z to axis 2.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Corner offsets in (axis0, axis1, axis2) for the 8 cube corners.
# Row k is the offset of corner k from the cell origin, using the corner numbering
# listed in the module docstring. marching_cubes uses these rows both as integer
# offsets into the volume and, converted to float, as corner positions in index space.
_CORNERS = np.array([
    [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
    [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1],
], dtype=np.int64)

# The two corner indices that each of the 12 edges connects (row e = edge e).
# NOTE: the pairs are not all ordered along the positive axis direction. With the
# corner offsets above, edges 2, 3, 6 and 7 run from the higher-coordinate corner to
# the lower one, so the same grid edge can be listed in opposite directions by
# neighbouring cells.
_EDGE_CORNERS = np.array([
    [0, 1], [1, 2], [2, 3], [3, 0],
    [4, 5], [5, 6], [6, 7], [7, 4],
    [0, 4], [1, 5], [2, 6], [3, 7],
], dtype=np.int64)
|
||||
|
||||
# Standard 256 x 16 triangle table. For cube configuration `i`, lists triples of
|
||||
# edge indices forming triangles, terminated by -1.
|
||||
_TRI_TABLE = [
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1],
|
||||
[8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1],
|
||||
[3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1],
|
||||
[4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1],
|
||||
[4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1],
|
||||
[9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1],
|
||||
[10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1],
|
||||
[5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1],
|
||||
[5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1],
|
||||
[8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1],
|
||||
[2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1],
|
||||
[2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1],
|
||||
[11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1],
|
||||
[5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1],
|
||||
[11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1],
|
||||
[11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1],
|
||||
[2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1],
|
||||
[6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1],
|
||||
[3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1],
|
||||
[6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1],
|
||||
[6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1],
|
||||
[8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1],
|
||||
[7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1],
|
||||
[3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1],
|
||||
[0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1],
|
||||
[9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1],
|
||||
[8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1],
|
||||
[5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1],
|
||||
[0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1],
|
||||
[6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1],
|
||||
[10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1],
|
||||
[1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1],
|
||||
[0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1],
|
||||
[3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1],
|
||||
[6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1],
|
||||
[9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1],
|
||||
[8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1],
|
||||
[3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1],
|
||||
[10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1],
|
||||
[10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1],
|
||||
[2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1],
|
||||
[7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1],
|
||||
[2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1],
|
||||
[1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1],
|
||||
[11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1],
|
||||
[8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1],
|
||||
[0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1],
|
||||
[7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1],
|
||||
[7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1],
|
||||
[10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1],
|
||||
[0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1],
|
||||
[7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1],
|
||||
[6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1],
|
||||
[4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1],
|
||||
[10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1],
|
||||
[8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1],
|
||||
[1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1],
|
||||
[10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1],
|
||||
[10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1],
|
||||
[9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1],
|
||||
[7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1],
|
||||
[3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1],
|
||||
[7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1],
|
||||
[3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1],
|
||||
[6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1],
|
||||
[9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1],
|
||||
[1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1],
|
||||
[4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1],
|
||||
[7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1],
|
||||
[6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1],
|
||||
[0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1],
|
||||
[6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1],
|
||||
[0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1],
|
||||
[11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1],
|
||||
[6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1],
|
||||
[5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1],
|
||||
[9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1],
|
||||
[1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1],
|
||||
[10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1],
|
||||
[0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1],
|
||||
[11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1],
|
||||
[9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1],
|
||||
[7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1],
|
||||
[2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1],
|
||||
[9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1],
|
||||
[9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1],
|
||||
[1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1],
|
||||
[0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1],
|
||||
[10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1],
|
||||
[2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1],
|
||||
[0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1],
|
||||
[0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1],
|
||||
[9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1],
|
||||
[5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1],
|
||||
[5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1],
|
||||
[8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1],
|
||||
[9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1],
|
||||
[1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1],
|
||||
[3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1],
|
||||
[4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1],
|
||||
[9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1],
|
||||
[11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1],
|
||||
[11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1],
|
||||
[2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1],
|
||||
[9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1],
|
||||
[3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1],
|
||||
[1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1],
|
||||
[4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1],
|
||||
[0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1],
|
||||
[9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1],
|
||||
[1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
def marching_cubes(volume: torch.Tensor, level: float = 0.0):
    """Extract an isosurface from a 3D scalar field.

    Args:
        volume: (D, H, W) float tensor. Inside is where ``volume < level`` (matches
            the classic Lorensen convention and skimage's ``method="lorensen"``).
        level: isosurface threshold.

    Returns:
        (vertices, faces): numpy arrays. ``vertices`` are float32 (N, 3) in array-index
        coordinates (axis0, axis1, axis2); ``faces`` are int64 (M, 3). Both are empty
        (0, 3) arrays when the volume has fewer than two samples along any axis or when
        no cell straddles ``level``.

    Raises:
        AssertionError: if ``volume`` is not 3-dimensional.
    """
    assert volume.ndim == 3, "volume must be (D, H, W)"
    device = volume.device
    vol = volume.float()

    # Number of cells along each axis (one fewer than the number of samples).
    nc0, nc1, nc2 = vol.shape[0] - 1, vol.shape[1] - 1, vol.shape[2] - 1
    if nc0 <= 0 or nc1 <= 0 or nc2 <= 0:
        return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.int64))

    # Plain Python ints so the offsets can be used directly for slicing and indexing.
    corner_offsets = _CORNERS.tolist()

    # Cube configuration index: bit k set when corner k is inside (val < level).
    # Accumulated one corner at a time so that no (nc0, nc1, nc2, 8) tensor is
    # materialised for the whole grid.
    cube_index = torch.zeros((nc0, nc1, nc2), dtype=torch.long, device=device)
    for k, (o0, o1, o2) in enumerate(corner_offsets):
        corner_inside = vol[o0:o0 + nc0, o1:o1 + nc1, o2:o2 + nc2] < level
        cube_index += corner_inside.long() * (1 << k)

    # Cells that actually intersect the surface.
    active = (cube_index > 0) & (cube_index < 255)
    if not active.any():
        return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.int64))

    idx0, idx1, idx2 = torch.where(active)                            # (Nactive,)
    cidx = cube_index[idx0, idx1, idx2]                               # (Nactive,)
    cell_origin = torch.stack([idx0, idx1, idx2], dim=1).float()      # (Nactive, 3)
    # Corner scalar values, gathered for the active cells only.
    cell_vals = torch.stack(
        [vol[idx0 + o0, idx1 + o1, idx2 + o2] for o0, o1, o2 in corner_offsets],
        dim=1,
    )                                                                 # (Nactive, 8)

    tri_table = torch.tensor(_TRI_TABLE, dtype=torch.long, device=device)        # (256, 16)
    edge_corners = torch.tensor(_EDGE_CORNERS, dtype=torch.long, device=device)  # (12, 2)
    corners = torch.tensor(_CORNERS, dtype=torch.float32, device=device)         # (8, 3)

    # Orient every edge from its lower-coordinate endpoint to its higher one. The edge
    # table lists some edges high->low (2, 3, 6, 7), so neighbouring cells would
    # otherwise interpolate the same grid edge from opposite ends and could produce
    # positions differing in the last float bits, which defeats the exact weld below.
    first, second = edge_corners[:, 0], edge_corners[:, 1]
    runs_backwards = (corners[first] > corners[second]).any(dim=1)   # (12,)
    edge_lo = torch.where(runs_backwards, second, first)
    edge_hi = torch.where(runs_backwards, first, second)

    tris = tri_table[cidx]                                            # (Nactive, 16)

    # Each row holds up to 5 triangles (15 edge entries). Expand to (Nactive, 5, 3).
    tri_edges = tris[:, :15].reshape(-1, 5, 3)                        # edge indices, -1 = unused
    valid_tri = tri_edges[..., 0] >= 0                                # (Nactive, 5)

    cell_idx = torch.arange(cell_origin.shape[0], device=device).unsqueeze(1).expand(-1, 5)
    cell_idx = cell_idx[valid_tri]                                    # (T,)
    edges = tri_edges[valid_tri]                                      # (T, 3) edge index per triangle corner

    # Interpolate a vertex on each referenced edge.
    e_flat = edges.reshape(-1)                                        # (T*3,)
    cell_for_vert = cell_idx.unsqueeze(1).expand(-1, 3).reshape(-1)   # (T*3,)

    ca = edge_lo[e_flat]                                              # lower endpoint corner index
    cb = edge_hi[e_flat]                                              # higher endpoint corner index
    va = cell_vals[cell_for_vert, ca]                                 # scalar at the lower endpoint
    vb = cell_vals[cell_for_vert, cb]
    origin = cell_origin[cell_for_vert]
    pa = origin + corners[ca]                                         # endpoint positions (index space)
    pb = origin + corners[cb]

    denom = vb - va
    # Guard against a (near-)zero denominator; such edges fall back to the lower endpoint.
    t = torch.where(denom.abs() > 1e-12, (level - va) / denom, torch.zeros_like(denom))
    t = t.clamp(0.0, 1.0).unsqueeze(1)
    verts = pa + t * (pb - pa)                                        # (T*3, 3) one vertex per triangle corner

    # Weld shared vertices: with the canonical edge orientation above, every cell that
    # touches a grid edge feeds identical endpoint values and positions into the same
    # arithmetic, so exact dedup yields a clean indexed mesh.
    uniq, inverse = torch.unique(verts, dim=0, return_inverse=True)
    faces = inverse.reshape(-1, 3)

    return (uniq.cpu().numpy().astype(np.float32), faces.cpu().numpy().astype(np.int64))
|
||||
364
comfy/ldm/cube/vae.py
Normal file
364
comfy/ldm/cube/vae.py
Normal file
@ -0,0 +1,364 @@
|
||||
"""
|
||||
Native port of Roblox/cube's shape tokenizer decode path (OneDAutoEncoder).
|
||||
|
||||
Reference: https://github.com/Roblox/cube (cube3d/model/autoencoder/*).
|
||||
|
||||
Only the DECODE path is ported (token IDs -> latents -> occupancy grid -> mesh);
|
||||
the point-cloud encoder is not needed for text-to-3D generation. Encoder weights in
|
||||
the checkpoint are loaded with strict=False and ignored.
|
||||
|
||||
Module/parameter names mirror upstream so the checkpoint loads directly:
|
||||
embedder.weight
|
||||
bottleneck.block.{codebook, cb_weight, cb_bias, c_in, c_x, c_out, ...}
|
||||
decoder.{positional_encodings, blocks.N...}
|
||||
occupancy_decoder.{query_in, attn_out, ln_f, c_head}
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Norms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CubeLayerNorm(nn.Module):
    """Layer normalisation computed in float32 and cast back to the input dtype.

    With ``elementwise_affine=False`` (the default) the module owns no parameters and
    ``weight`` / ``bias`` are plain ``None`` attributes; otherwise they are learnable
    tensors of size ``dim`` initialised to ones and zeros respectively.
    """

    def __init__(self, dim, eps=1e-6, elementwise_affine=False, dtype=None, device=None):
        super().__init__()
        self.dim = (dim,)  # normalized_shape as expected by F.layer_norm
        self.eps = eps
        self.weight = (
            nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
            if elementwise_affine else None
        )
        self.bias = (
            nn.Parameter(torch.zeros(dim, dtype=dtype, device=device))
            if elementwise_affine else None
        )

    def forward(self, x):
        """Normalise ``x`` over its last dimension; the result has the dtype of ``x``."""
        scale = self.weight
        shift = self.bias
        # Upcast the affine terms (when present) so the whole op runs in float32.
        if scale is not None:
            scale = scale.float()
        if shift is not None:
            shift = shift.float()
        normed = F.layer_norm(x.float(), self.dim, scale, shift, self.eps)
        return normed.type_as(x)
|
||||
|
||||
|
||||
class CubeRMSNorm(nn.Module):
    """RMS normalisation over the last dimension, computed in float32.

    ``weight`` is a learnable parameter when ``elementwise_affine`` is true. Otherwise
    it is a non-persistent buffer of ones, so it never appears in the state dict; that
    buffer is created without the ``dtype`` / ``device`` arguments.
    """

    def __init__(self, dim, eps=1e-5, elementwise_affine=True, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        if not elementwise_affine:
            self.register_buffer("weight", torch.ones(dim), persistent=False)
            return
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))

    def forward(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis, then by ``weight``."""
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        scaled = x32 * inv_rms
        return (scaled * self.weight.float()).type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fourier embedder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PhaseModulatedFourierEmbedder(nn.Module):
    """Fourier feature embedding with learned frequencies plus a fixed phase carrier.

    ``weight`` (``input_dim`` x ``num_freqs``) is allocated uninitialised and is
    expected to be filled from a checkpoint. ``carrier`` is a fixed, non-persistent
    buffer derived from ``num_freqs``. The output width is ``out_dim`` =
    ``input_dim * (2 * num_freqs + 1)``: the raw input followed by the cosine and the
    sine features.
    """

    def __init__(self, num_freqs, input_dim=3, dtype=None, device=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(input_dim, num_freqs, dtype=dtype, device=device))
        descending = torch.linspace(1, 0, num_freqs)
        ascending = torch.linspace(0, 1, num_freqs)
        base = (num_freqs / 8) ** descending
        carrier = (base + ascending) * 2 * math.pi
        self.register_buffer("carrier", carrier, persistent=False)
        self.out_dim = input_dim * (num_freqs * 2 + 1)

    def forward(self, x):
        """Embed ``x`` (last axis = input coordinates); result has the dtype of ``x``."""
        lead_shape = x.shape[:-1]
        coords = x.float().unsqueeze(-1)  # trailing axis broadcasts against the frequencies
        freq_mod = (coords * self.weight.float()).view(*lead_shape, -1)
        phase_mod = (coords * 0.5 * math.pi + self.carrier.float()).view(*lead_shape, -1)
        cos_feats = freq_mod.cos() + phase_mod.cos()
        sin_feats = freq_mod.sin() + phase_mod.sin()
        return torch.cat([x, cos_feats, sin_feats], dim=-1).type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attention building blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLP(nn.Module):
    """Feed-forward block: ``up_proj`` -> exact (non-approximated) GELU -> ``down_proj``."""

    def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.up_proj = ops.Linear(embed_dim, hidden_dim, bias=bias, **factory)
        self.down_proj = ops.Linear(hidden_dim, embed_dim, bias=bias, **factory)
        self.act_fn = nn.GELU(approximate="none")

    def forward(self, x):
        """Expand to ``hidden_dim``, apply the activation, project back to ``embed_dim``."""
        hidden = self.up_proj(x)
        hidden = self.act_fn(hidden)
        return self.down_proj(hidden)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalised queries and keys.

    Queries and keys come from one fused projection (``c_qk``); values from ``c_v``.
    ``eps`` is accepted by the constructor but is not used in this class: the q/k
    norms are built with ``CubeRMSNorm``'s own default epsilon.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None):
        super().__init__()
        head_dim, leftover = divmod(embed_dim, num_heads)
        assert leftover == 0
        self.num_heads = num_heads
        factory = {"dtype": dtype, "device": device}
        self.c_qk = ops.Linear(embed_dim, 2 * embed_dim, bias=bias, **factory)
        self.c_v = ops.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.q_norm = CubeRMSNorm(head_dim, **factory)
        self.k_norm = CubeRMSNorm(head_dim, **factory)

    def forward(self, x, attn_mask=None, is_causal=False):
        """Attend over a rank-3 input ``x`` and return a tensor of the same shape.

        ``is_causal`` is only forwarded to the attention kernel when no explicit
        ``attn_mask`` is given.
        """
        batch, length, width = x.shape
        heads = self.num_heads

        def to_heads(t):
            # (batch, length, width) -> (batch, heads, length, head_dim)
            return t.view(batch, length, heads, -1).transpose(1, 2)

        queries, keys = self.c_qk(x).chunk(2, dim=-1)
        values = self.c_v(x)
        queries = self.q_norm(to_heads(queries))
        keys = self.k_norm(to_heads(keys))
        values = to_heads(values)

        causal = is_causal and attn_mask is None
        attended = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask=attn_mask, dropout_p=0.0, is_causal=causal,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, length, width)
        return self.c_proj(merged)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys and values from ``c``.

    ``q_dim`` and ``kv_dim`` are the input widths of the query and key/value
    projections; each falls back to ``embed_dim`` when not given (or falsy).
    """

    def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, dtype=None, device=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        if not q_dim:
            q_dim = embed_dim
        if not kv_dim:
            kv_dim = embed_dim
        factory = {"dtype": dtype, "device": device}
        self.c_q = ops.Linear(q_dim, embed_dim, bias=bias, **factory)
        self.c_k = ops.Linear(kv_dim, embed_dim, bias=bias, **factory)
        self.c_v = ops.Linear(kv_dim, embed_dim, bias=bias, **factory)
        self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.num_heads = num_heads

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` to ``c``; the output keeps the projected shape of ``x``."""
        queries = self.c_q(x)
        keys = self.c_k(c)
        values = self.c_v(c)

        batch, q_len, width = queries.shape
        kv_len = keys.shape[1]
        heads = self.num_heads

        # Split the channel axis into heads and move the head axis ahead of the sequence axis.
        queries = queries.view(batch, q_len, heads, -1).transpose(1, 2)
        keys = keys.view(batch, kv_len, heads, -1).transpose(1, 2)
        values = values.view(batch, kv_len, heads, -1).transpose(1, 2)

        attended = F.scaled_dot_product_attention(queries, keys, values, attn_mask=attn_mask, dropout_p=0.0)
        merged = attended.transpose(1, 2).contiguous().view(batch, q_len, width)
        return self.c_proj(merged)
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each with a residual add.

    Both layer norms are built without affine parameters; the MLP hidden width is
    four times ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
        self.attn = SelfAttention(embed_dim, num_heads, bias=bias, eps=eps, **factory)
        self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, **factory)

    def forward(self, x, attn_mask=None, is_causal=False):
        """Apply the attention and MLP sub-layers to ``x`` and return the result."""
        attn_out = self.attn(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
|
||||
|
||||
|
||||
class EncoderCrossAttentionLayer(nn.Module):
    """Pre-norm cross-attention layer followed by an MLP, each with a residual add.

    ``ln_1`` normalises the query stream, ``ln_2`` the context stream and ``ln_f``
    the input of the MLP. ``q_dim`` / ``kv_dim`` fall back to ``embed_dim`` when not
    given (or falsy).
    """

    def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, eps=1e-6, dtype=None, device=None):
        super().__init__()
        if not q_dim:
            q_dim = embed_dim
        if not kv_dim:
            kv_dim = embed_dim
        factory = {"dtype": dtype, "device": device}
        self.attn = CrossAttention(embed_dim, num_heads, q_dim=q_dim, kv_dim=kv_dim, bias=bias, **factory)
        self.ln_1 = CubeLayerNorm(q_dim, eps=eps)
        self.ln_2 = CubeLayerNorm(kv_dim, eps=eps)
        self.ln_f = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, **factory)

    def forward(self, x, c, attn_mask=None):
        """Let ``x`` attend to the context ``c`` and return the updated ``x``."""
        attended = self.attn(self.ln_1(x), self.ln_2(c), attn_mask=attn_mask)
        x = x + attended
        return x + self.mlp(self.ln_f(x))
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
    """Two linear layers with a SiLU in between, mapping ``in_dim`` to ``embed_dim``."""

    def __init__(self, in_dim, embed_dim, bias=True, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.in_layer = ops.Linear(in_dim, embed_dim, bias=bias, **factory)
        self.silu = nn.SiLU()
        self.out_layer = ops.Linear(embed_dim, embed_dim, bias=bias, **factory)

    def forward(self, x):
        """Return ``out_layer(silu(in_layer(x)))``."""
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Spherical VQ (decode-only parts)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SphericalVectorQuantizer(nn.Module):
    """Codebook lookup part of the vector quantizer (decode-only).

    The raw codebook rows pass through a per-channel affine (``cb_weight`` /
    ``cb_bias``) and a non-affine RMS norm before use. When ``width`` differs from
    ``embed_dim``, ``c_in`` / ``c_x`` / ``c_out`` are linear maps between the two
    widths; otherwise all three share a single ``nn.Identity``.
    """

    def __init__(self, embed_dim, num_codes, width=None, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.num_codes = num_codes
        self.codebook = ops.Embedding(num_codes, embed_dim, **factory)

        if not width:
            width = embed_dim
        if width == embed_dim:
            passthrough = nn.Identity()
            self.c_in = passthrough
            self.c_out = passthrough
            self.c_x = passthrough
        else:
            self.c_in = ops.Linear(width, embed_dim, **factory)
            self.c_x = ops.Linear(width, embed_dim, **factory)
            self.c_out = ops.Linear(embed_dim, width, **factory)

        self.norm = CubeRMSNorm(embed_dim, elementwise_affine=False, **factory)
        # "kl" codebook regularization (released config)
        self.cb_weight = nn.Parameter(torch.ones([embed_dim], **factory))
        self.cb_bias = nn.Parameter(torch.zeros([embed_dim], **factory))

    def cb_norm(self, x):
        """Apply the per-channel codebook affine: ``x * cb_weight + cb_bias``."""
        scaled = x * self.cb_weight
        return scaled + self.cb_bias

    def get_codebook(self):
        """Return the full codebook after the affine and the RMS normalisation."""
        adjusted = self.cb_norm(self.codebook.weight)
        return self.norm(adjusted)

    def lookup_codebook(self, q):
        """Map code indices ``q`` to latent vectors projected through ``c_out``."""
        table = self.get_codebook()
        selected = F.embedding(q, table)
        return self.c_out(selected)
|
||||
|
||||
|
||||
class OneDBottleNeck(nn.Module):
    """Thin container that stores ``block`` as the ``block`` attribute.

    It defines no ``forward``; when ``block`` is a module, its parameters appear
    under the ``block.`` prefix of this module's state dict.
    """

    def __init__(self, block):
        super().__init__()
        self.block = block
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OneDDecoder(nn.Module):
    """Stack of ``EncoderLayer`` blocks applied to latents plus positional encodings."""

    def __init__(self, num_latents, width, num_heads, num_layers, eps=1e-6, dtype=None, device=None):
        super().__init__()
        # Empty, non-persistent buffer; it is not read anywhere in this class.
        self.register_buffer("query", torch.empty([0, width]), persistent=False)
        self.positional_encodings = nn.Parameter(
            torch.empty(num_latents, width, dtype=dtype, device=device)
        )
        layers = [
            EncoderLayer(width, num_heads, eps=eps, dtype=dtype, device=device)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, z):
        """Add the leading ``z.shape[1]`` positional encodings to ``z`` and run every block."""
        seq_len = z.shape[1]
        positions = self.positional_encodings[:seq_len].unsqueeze(0).to(z.dtype)
        h = z + positions
        for layer in self.blocks:
            h = layer(h)
        return h
|
||||
|
||||
|
||||
class OneDOccupancyDecoder(nn.Module):
    """Maps query points and latents to ``out_features`` values per query.

    Pipeline: ``embedder`` -> ``query_in`` -> ``attn_out`` (with the latents)
    -> ``ln_f`` -> ``c_head``.
    """

    def __init__(self, embedder, out_features, width, num_heads, eps=1e-6, dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.embedder = embedder
        self.query_in = MLPEmbedder(embedder.out_dim, width, **factory)
        self.attn_out = EncoderCrossAttentionLayer(width, num_heads, **factory)
        self.ln_f = CubeLayerNorm(width, eps=eps, elementwise_affine=True, **factory)
        self.c_head = ops.Linear(width, out_features, **factory)

    def forward(self, queries, latents):
        """Return the head output for ``queries`` conditioned on ``latents``."""
        embedded = self.embedder(queries)
        x = self.query_in(embedded)
        x = self.attn_out(x, latents)
        normed = self.ln_f(x)
        return self.c_head(normed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level shape VAE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij"):
    """Build a dense, axis-aligned grid of 3D sample points inside a bounding box.

    Args:
        bbox_min: Minimum corner (x, y, z); any sequence or array of three numbers.
        bbox_max: Maximum corner (x, y, z); any sequence or array of three numbers.
        resolution_base: The grid has ``2 ** resolution_base`` cells per axis,
            i.e. ``2 ** resolution_base + 1`` sample points per axis.
        indexing: Passed through to ``np.meshgrid``.

    Returns:
        ``(xyz, grid_size, length)``: ``xyz`` is a float32 array of shape
        ``(points_per_axis ** 3, 3)``, ``grid_size`` is the list
        ``[points_per_axis] * 3`` and ``length`` is ``bbox_max - bbox_min``.
    """
    # Coerce so plain tuples/lists work too (tuple - tuple is a TypeError).
    bbox_min = np.asarray(bbox_min)
    bbox_max = np.asarray(bbox_max)
    length = bbox_max - bbox_min

    # Cells per axis + 1 == sample points per axis; computed once and reused.
    points_per_axis = int(np.exp2(resolution_base)) + 1
    x, y, z = (
        np.linspace(bbox_min[axis], bbox_max[axis], points_per_axis, dtype=np.float32)
        for axis in range(3)
    )
    xs, ys, zs = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1).reshape(-1, 3)
    grid_size = [points_per_axis] * 3
    return xyz, grid_size, length
|
||||
|
||||
|
||||
class CubeShapeVAE(nn.Module):
    """Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored)."""

    # Fixed query bounds for the occupancy grid (upstream default).
    decode_bounds = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05)

    def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12,
                 num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6,
                 dtype=None, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        # Kept so decode() can truncate and clamp incoming token IDs.
        self.cfg_num_encoder_latents = num_encoder_latents
        self.cfg_num_codes = num_codes

        self.embedder = PhaseModulatedFourierEmbedder(num_freqs=num_freqs, input_dim=3, **factory)
        quantizer = SphericalVectorQuantizer(embed_dim, num_codes, width, **factory)
        self.bottleneck = OneDBottleNeck(quantizer)
        self.decoder = OneDDecoder(
            num_encoder_latents, width, num_heads, num_decoder_layers, eps=eps, **factory
        )
        self.occupancy_decoder = OneDOccupancyDecoder(
            self.embedder, out_dim, width, num_heads, eps=eps, **factory
        )

    @torch.no_grad()
    def decode(self, samples, resolution_base=8.0, chunk_size=100_000, **kwargs):
        """Token IDs -> occupancy grid logits. Entry point for comfy.sd.VAE.decode, which
        manages model loading/device/dtype. `samples` arrive as (B, 1, num_tokens) in the
        VAE working dtype on the load device. VAE.decode applies a trailing movedim(1, -1),
        so pre-invert it here to hand the node grid logits as (B, gx, gy, gz)."""
        flat = samples.reshape(samples.shape[0], -1)
        ids = flat[:, :self.cfg_num_encoder_latents]
        # IDs arrive as floats: round to the nearest code and keep them in range.
        ids = ids.round().long().clamp(0, self.cfg_num_codes - 1)
        latents = self.decode_indices(ids)
        geometry = self.extract_geometry(
            latents,
            bounds=self.decode_bounds,
            resolution_base=resolution_base,
            chunk_size=chunk_size,
        )
        grid_logits = geometry[0]
        return grid_logits.movedim(-1, 1)

    @torch.no_grad()
    def decode_indices(self, shape_ids):
        """Look up codebook vectors for ``shape_ids`` and run them through the decoder."""
        quantized = self.bottleneck.block.lookup_codebook(shape_ids)
        return self.decoder(quantized)

    @torch.no_grad()
    def query(self, queries, latents):
        """Evaluate the occupancy decoder and drop its trailing size-1 dimension."""
        logits = self.occupancy_decoder(queries, latents)
        return logits.squeeze(-1)

    @torch.no_grad()
    def extract_geometry(self, latents, bounds=(-1.05, -1.05, -1.05, 1.05, 1.05, 1.05),
                         resolution_base=8.0, chunk_size=100_000):
        """Query occupancy logits on a dense grid spanning ``bounds``.

        Returns ``(grid_logits, grid_size, bbox_size, bbox_min)`` where
        ``grid_logits`` is a float32 tensor of shape ``(B, gx, gy, gz)``.
        """
        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min

        points, grid_size, _ = generate_dense_grid_points(
            bbox_min, bbox_max, resolution_base, indexing="ij"
        )
        points = torch.from_numpy(points)
        batch_size = latents.shape[0]
        total_points = points.shape[0]

        per_chunk_logits = []
        for start in range(0, total_points, chunk_size):
            queries = points[start:start + chunk_size, :]
            valid = queries.shape[0]
            # A short chunk after the first is zero-padded up to chunk_size;
            # the padded rows are sliced off again right after the query.
            if start > 0 and valid < chunk_size:
                queries = F.pad(queries, [0, 0, 0, chunk_size - valid])
            batched = queries.unsqueeze(0).expand(batch_size, -1, -1).to(latents)
            per_chunk_logits.append(self.query(batched, latents)[:, :valid])

        gx, gy, gz = grid_size[0], grid_size[1], grid_size[2]
        flat_logits = torch.cat(per_chunk_logits, dim=1).detach()
        grid_logits = flat_logits.view(batch_size, gx, gy, gz).float()
        return grid_logits, grid_size, bbox_size, bbox_min
|
||||
|
||||
|
||||
def grid_logits_to_mesh(grid_logit, grid_size, bbox_size, bbox_min, level=0.0):
    """Occupancy-logit grid -> mesh, using the vendored dependency-free marching cubes
    (classic Lorensen, same family as upstream cube's default warp backend). Vertices are
    rescaled from grid-index space into the bbox, matching upstream's transform.

    Returns ``(vertices, faces)`` with float32 vertices and contiguous faces."""
    from comfy.ldm.cube.marching_cubes import marching_cubes

    index_vertices, faces = marching_cubes(grid_logit, level)
    # Grid-index coordinates -> bbox coordinates.
    normalized = index_vertices / np.array(grid_size)
    world_vertices = normalized * bbox_size + bbox_min
    # The vendored Lorensen table already emits outward-facing winding for this
    # occupancy convention, so (unlike the upstream skimage path) no face flip is needed.
    return world_vertices.astype(np.float32), np.ascontiguousarray(faces)
|
||||
@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module):
|
||||
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, dtype):
|
||||
x = x.to(torch.float32)
|
||||
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
|
||||
emb = _sinusoidal_embedding(scaled, self.dim)
|
||||
emb = emb.to(self.mlp_in.weight.dtype)
|
||||
emb = emb.to(dtype)
|
||||
emb = F.silu(self.mlp_in(emb))
|
||||
return self.mlp_out(emb)
|
||||
|
||||
@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module):
|
||||
x = x * output_image_mask
|
||||
h = self.input_proj(x) * output_image_mask
|
||||
|
||||
t_cond = self.t_embedding(t)
|
||||
t_cond = self.t_embedding(t, dtype=x.dtype)
|
||||
if t.dim() == 1:
|
||||
t_cond = t_cond.unsqueeze(1)
|
||||
adaln_input = F.silu(self.adaln_proj(t_cond))
|
||||
|
||||
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
@ -17,9 +18,7 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
return apply_rope1(x, freqs_cis)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -44,6 +44,7 @@ import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.cube.gpt
|
||||
import comfy.ldm.wan.model_wandancer
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.triposplat.model
|
||||
@ -1903,6 +1904,26 @@ class Hunyuan3Dv2(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class Cube3D(BaseModel):
    """Roblox Cube3D shape GPT (autoregressive). Generation goes through the
    dedicated `cube` sampler (SamplerCustomAdvanced), never KSampler/apply_model."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.cube.gpt.DualStreamRoformer,
        )

    def _apply_model(self, *args, **kwargs):
        """Always raises: this model cannot be sampled through the diffusion path."""
        message = (
            "Cube3D is an autoregressive token model. Use the 'cube' sampler "
            "(SamplerCube + SamplerCustomAdvanced), not KSampler."
        )
        raise RuntimeError(message)

    def extra_conds(self, **kwargs):
        """Return ``{'c_crossattn': ...}`` when ``cross_attn`` is supplied, else ``{}``."""
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is None:
            return {}
        return {'c_crossattn': comfy.conds.CONDRegular(cross_attn)}
|
||||
|
||||
|
||||
class Hunyuan3Dv2_1(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
|
||||
|
||||
@ -654,6 +654,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}shape_proj.weight'.format(key_prefix) in state_dict_keys and '{}lm_head.weight'.format(key_prefix) in state_dict_keys: # Roblox Cube3D shape GPT
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cube3d"
|
||||
n_embd = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["n_embd"] = n_embd
|
||||
dit_config["shape_model_vocab_size"] = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[0] - 3
|
||||
dit_config["n_layer"] = count_blocks(state_dict_keys, '{}transformer.dual_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["n_single_layer"] = count_blocks(state_dict_keys, '{}transformer.single_blocks.'.format(key_prefix) + '{}.')
|
||||
head_dim = state_dict['{}transformer.dual_blocks.0.attn.pre_x.q_norm.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["n_head"] = n_embd // head_dim
|
||||
dit_config["shape_model_embed_dim"] = state_dict['{}shape_proj.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["text_model_embed_dim"] = state_dict['{}text_proj.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["use_bbox"] = '{}bbox_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["bias"] = '{}text_proj.bias'.format(key_prefix) in state_dict_keys
|
||||
dit_config["rope_theta"] = 10000 # not stored in the state dict; upstream's fixed constant
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
|
||||
dit_config = {}
|
||||
|
||||
@ -643,6 +643,8 @@ def free_pins(size, evict_active=False):
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
if args.high_ram:
|
||||
return True
|
||||
if args.fast_disk:
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
else:
|
||||
@ -1496,6 +1498,8 @@ if not args.disable_pinned_memory:
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def pinned_hostbuf_size(size):
    """Return twice ``size`` as a non-negative int.

    Unless ``args.high_ram`` is set, ``size`` is first capped at
    ``MAX_PINNED_MEMORY``.
    """
    effective = size if args.high_ram else min(size, MAX_PINNED_MEMORY)
    return max(0, int(effective * 2))
|
||||
|
||||
def discard_cuda_async_error():
|
||||
|
||||
@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||
return
|
||||
if signature is None:
|
||||
if signature is None or args.high_ram:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||
|
||||
34
comfy/sd.py
34
comfy/sd.py
@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.cube.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
@ -777,6 +778,39 @@ class VAE:
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
# Roblox Cube3D shape tokenizer (OneDAutoEncoder, decode-only)
|
||||
elif "bottleneck.block.codebook.weight" in sd:
|
||||
self.latent_dim = 1
|
||||
# The VQ bottleneck (get_codebook/lookup_codebook) reads raw parameters
|
||||
# outside any hooked forward, so the streaming-offload cast hooks can't
|
||||
# relocate them; the model must be fully resident to decode. This is a
|
||||
# correctness requirement, declared via the standard flag (like the audio
|
||||
# VAEs) rather than managed manually in the node.
|
||||
self.disable_offload = True
|
||||
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
|
||||
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
|
||||
width = sd["bottleneck.block.c_out.weight"].shape[0]
|
||||
num_encoder_latents = sd["decoder.positional_encodings"].shape[0]
|
||||
head_dim = sd["decoder.blocks.0.attn.q_norm.weight"].shape[0]
|
||||
num_heads = width // head_dim
|
||||
num_freqs = sd["embedder.weight"].shape[1]
|
||||
num_decoder_layers = len({k.split(".")[2] for k in sd if k.startswith("decoder.blocks.")})
|
||||
self.first_stage_model = comfy.ldm.cube.vae.CubeShapeVAE(
|
||||
num_encoder_latents=num_encoder_latents, embed_dim=embed_dim, width=width,
|
||||
num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers,
|
||||
num_codes=num_codes,
|
||||
)
|
||||
# Decode goes through the managed comfy.sd.VAE.decode path; the grid logits
|
||||
# are float32 regardless of weight dtype, so keep process_output identity
|
||||
# (the default clamps to [0, 1] in-place and would destroy the isosurface).
|
||||
self.process_output = lambda image: image
|
||||
self.process_input = lambda image: image
|
||||
# shape is the token-ID latent (B, 1, num_tokens); size by num_tokens.
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[-1] * 768) * model_management.dtype_size(dtype)
|
||||
# fp32-only (unlike most VAEs that allow fp16/bf16): the VQ codebook lookup
|
||||
# and occupancy-grid query must run in fp32 to match upstream and keep the
|
||||
# isosurface stable.
|
||||
self.working_dtypes = [torch.float32]
|
||||
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
|
||||
|
||||
@ -1550,6 +1550,32 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
|
||||
class Cube3D(supported_models_base.BASE):
    """Supported-model entry for checkpoints detected as ``image_model == "cube3d"``."""

    unet_config = {"image_model": "cube3d"}
    unet_extra_config = {}
    sampling_settings = {}

    latent_format = latent_formats.Cube3D
    memory_usage_factor = 1.0

    # Upstream keeps fp32 weights and uses bf16 autocast during the forward pass
    # (see sample_cube). Prefer fp32 weights for parity; bf16 is the low-VRAM fallback.
    supported_inference_dtypes = [torch.float32, torch.bfloat16]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.Cube3D`` wrapper for this config."""
        model = model_base.Cube3D(self, device=device)
        return model

    def clip_target(self, state_dict={}):
        """Return ``None``: there is no text encoder to build for this checkpoint."""
        # No bundled text encoder: the cube checkpoint is GPT-only. The graph wires a
        # standard CLIPLoader(clip-l)/CLIPTextEncode, so there is no clip_target to build.
        return None
|
||||
|
||||
class TripoSplat(supported_models_base.BASE):
|
||||
# Image -> 3D gaussian splat flow denoiser
|
||||
unet_config = {
|
||||
@ -2292,6 +2318,7 @@ models = [
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
Cube3D,
|
||||
TripoSplat,
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
|
||||
@ -27,10 +27,13 @@ class VideoInput(ABC):
|
||||
path: Union[str, IO[bytes]],
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
|
||||
bit_depth selects the encoded bit depth; None keeps the video's native depth.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -83,6 +86,14 @@ class VideoInput(ABC):
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_bit_depth(self) -> int:
    """
    Returns the bit depth of the video (e.g. 8 or 10).

    This base implementation always reports 8; subclasses report their real depth.
    """
    default_bit_depth = 8
    return default_bit_depth
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
@ -52,6 +52,12 @@ def get_open_write_kwargs(
|
||||
return open_kwargs
|
||||
|
||||
|
||||
def video_stream_bit_depth(stream) -> int:
    """Return the largest per-component bit count of ``stream``'s pixel format.

    Falls back to 8 when the stream, its format, or the format's component
    list is missing or empty.
    """
    pixel_format = None if stream is None else stream.format
    components = None if pixel_format is None else pixel_format.components
    if not components:
        return 8
    return max(component.bits for component in components)
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_bit_depth(self) -> int:
    """Open the underlying file and return the bit depth of its first video stream.

    ``video_stream_bit_depth`` receives ``None`` when the container has no
    video stream.
    """
    source = self.__file
    if isinstance(source, io.BytesIO):
        source.seek(0)  # Reset the BytesIO object to the beginning
    with av.open(source, mode="r") as container:
        video_streams = container.streams.video
        first_video = video_streams[0] if len(video_streams) > 0 else None
        return video_stream_bit_depth(first_video)
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
@ -257,6 +270,7 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
image_format = 'gbrpf32le'
|
||||
process_image_format = lambda a: a
|
||||
align_graph = None
|
||||
audio = None
|
||||
|
||||
streams = [video_stream]
|
||||
@ -310,7 +324,24 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
checked_alpha = True
|
||||
|
||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||
# Fix non-deterministic video decode when the video width is not a multiple of 32
|
||||
# For non-yuvj pixel formats (all H.264/H.265 video)
|
||||
if image_format in ('gbrpf32le', 'gbrapf32le') and frame.width % 32 != 0:
|
||||
if align_graph is None:
|
||||
pad_w = ((frame.width + 31) // 32) * 32
|
||||
g = av.filter.Graph()
|
||||
g_src = g.add_buffer(width=frame.width, height=frame.height,
|
||||
format=frame.format.name, time_base=video_stream.time_base)
|
||||
g_pad = g.add('pad', f'{pad_w}:{frame.height}:0:0')
|
||||
g_sink = g.add('buffersink')
|
||||
g_src.link_to(g_pad)
|
||||
g_pad.link_to(g_sink)
|
||||
g.configure()
|
||||
align_graph = (g, g_src, g_sink)
|
||||
align_graph[1].push(frame)
|
||||
img = np.ascontiguousarray(align_graph[2].pull().to_ndarray(format=image_format)[:, :frame.width])
|
||||
else:
|
||||
img = frame.to_ndarray(format=image_format)
|
||||
if frame.rotation != 0:
|
||||
k = int(round(frame.rotation // 90))
|
||||
img = np.rot90(img, k=k, axes=(0, 1)).copy()
|
||||
@ -377,25 +408,32 @@ class VideoFromFile(VideoInput):
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
if bit_depth is None:
|
||||
bit_depth = source_bit_depth
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -451,8 +489,10 @@ class VideoFromComponents(VideoInput):
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
def __init__(self, components: VideoComponents, bit_depth: int = 8):
|
||||
self.__components = components
|
||||
# Tensor components have no inherent bit depth; this is the depth used when encoding.
|
||||
self.__bit_depth = bit_depth
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
@ -461,18 +501,26 @@ class VideoFromComponents(VideoInput):
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def get_bit_depth(self) -> int:
    """Return the bit depth stored on this instance."""
    depth = self.__bit_depth
    return depth
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
# None means "use the depth this video was created with" (CreateVideo's choice).
|
||||
if bit_depth is None:
|
||||
bit_depth = self.__bit_depth
|
||||
is_10bit = bit_depth >= 10
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
@ -488,10 +536,11 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
video_stream.pix_fmt = pix_fmt
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
@ -505,9 +554,14 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
if is_10bit:
|
||||
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
|
||||
else:
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format=pix_fmt)
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayTaskStatusEnum(str, Enum):
|
||||
SUCCEEDED = 'SUCCEEDED'
|
||||
RUNNING = 'RUNNING'
|
||||
FAILED = 'FAILED'
|
||||
PENDING = 'PENDING'
|
||||
CANCELLED = 'CANCELLED'
|
||||
THROTTLED = 'THROTTLED'
|
||||
|
||||
|
||||
class RunwayTaskStatusResponse(BaseModel):
|
||||
createdAt: datetime = Field(..., description='Task creation timestamp')
|
||||
id: str = Field(..., description='Task ID')
|
||||
@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel):
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
status: RunwayTaskStatusEnum
|
||||
status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED")
|
||||
|
||||
|
||||
class Model4(str, Enum):
|
||||
@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel):
|
||||
|
||||
class RunwayTextToImageResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayAleph2IO:
    """Socket type names used to chain Aleph2 guidance images between nodes."""

    # Socket type for keyframe chains.
    KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME"
    # Socket type for prompt-image chains.
    PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE"
|
||||
|
||||
|
||||
# Keyframe timing modes, relative to the INPUT video. The mode is stored on each chain
# item and selects the request model; values equal the Aleph2 keyframe union field names.
#   "seconds": absolute time, in seconds, from the start of the input video
#   "at":      fraction [0.0, 1.0] of the input video duration
KEYFRAME_MODE_SECONDS = "seconds"
KEYFRAME_MODE_AT = "at"

# Prompt-image position modes, relative to the OUTPUT video. Values equal the Aleph2
# position `type`.
#   "timestamp": absolute time, in seconds, from the start of the output video
#   "position":  fraction [0.0, 1.0] of the output video duration
PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp"
PROMPT_IMAGE_MODE_POSITION = "position"
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeItem:
    """A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``)."""

    def __init__(self, image, mode: str, value: float):
        # mode is KEYFRAME_MODE_SECONDS or KEYFRAME_MODE_AT; value is read per that mode.
        self.image, self.mode, self.value = image, mode, value
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeChain:
|
||||
"""An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2KeyframeItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2KeyframeItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2KeyframeChain":
|
||||
c = RunwayAleph2KeyframeChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageItem:
    """A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``)."""

    def __init__(self, image, mode: str, value: float):
        # mode is PROMPT_IMAGE_MODE_TIMESTAMP or PROMPT_IMAGE_MODE_POSITION.
        self.image, self.mode, self.value = image, mode, value
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageChain:
|
||||
"""An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2PromptImageItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2PromptImageItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2PromptImageChain":
|
||||
c = RunwayAleph2PromptImageChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
# Keyframe addressed by an absolute, non-negative time in seconds on the input video.
class RunwayAleph2KeyframeSeconds(BaseModel):
    seconds: float = Field(
        ...,
        ge=0.0,
        description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.",
    )
    uri: str = Field(...)
|
||||
|
||||
|
||||
# Keyframe addressed by a fraction (0.0 to 1.0 inclusive) of the input video duration.
class RunwayAleph2KeyframeAt(BaseModel):
    at: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Position as a fraction [0.0, 1.0] of the input video duration.",
    )
    uri: str = Field(...)
|
||||
|
||||
|
||||
# Prompt-image position given as an absolute, non-negative time on the output video.
class RunwayAleph2TimestampPosition(BaseModel):
    type: str = Field("timestamp")
    timestampSeconds: float = Field(
        ...,
        ge=0.0,
        description="Absolute timestamp in seconds from the start of the output video.",
    )
|
||||
|
||||
|
||||
# Prompt-image position given as a fraction (0.0 to 1.0 inclusive) of the output duration.
class RunwayAleph2RelativePosition(BaseModel):
    type: str = Field("position")
    positionPercentage: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Position as a fraction [0.0, 1.0] of the total output video duration.",
    )
|
||||
|
||||
|
||||
# One prompt image: where it applies (timestamp or relative position) and its URI.
class RunwayAleph2PromptImage(BaseModel):
    position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
    uri: str = Field(...)
|
||||
|
||||
|
||||
# Content-moderation settings sent with an Aleph2 request.
class RunwayAleph2ContentModeration(BaseModel):
    publicFigureThreshold: str = Field(
        ...,
        description=(
            'When set to "low", the content moderation system is less strict about '
            'recognizable public figures. One of "auto" or "low".'
        ),
    )
|
||||
|
||||
|
||||
# Request body for an Aleph2 video edit. `keyframes` and `promptImage` are optional;
# the "up to 5" limits in their descriptions are not enforced by this model.
class RunwayAleph2Request(BaseModel):
    model: str = Field("aleph2")
    promptText: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="A non-empty string describing what should appear in the output.",
    )
    videoUri: str = Field(...)
    seed: int = Field(..., ge=0, le=4294967295, description="Random seed for generation")
    contentModeration: RunwayAleph2ContentModeration = Field(...)
    keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field(
        default=None,
        description="Timed guidance images placed at specific points in the input video. Up to 5.",
    )
    promptImage: list[RunwayAleph2PromptImage] | None = Field(
        default=None,
        description="Up to 5 image keyframes for guiding the edit at specific points in the output video.",
    )
|
||||
|
||||
|
||||
# Response to an Aleph2 request; the task ID may be absent.
class RunwayAleph2Response(BaseModel):
    id: str | None = Field(default=None, description="Task ID")
|
||||
|
||||
@ -208,6 +208,10 @@ class TripoMultiviewToModelRequest(BaseModel):
|
||||
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
|
||||
|
||||
|
||||
# Optional text guidance for texture generation.
class TripoTexturePrompt(BaseModel):
    text: str | None = Field(default=None, description="Text guidance for texture generation")
|
||||
|
||||
|
||||
class TripoTextureModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
|
||||
original_model_task_id: str = Field(..., description="The task ID of the original model")
|
||||
@ -219,6 +223,11 @@ class TripoTextureModelRequest(BaseModel):
|
||||
texture_alignment: TripoTextureAlignment | None = Field(
|
||||
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
|
||||
)
|
||||
texture_prompt: TripoTexturePrompt | None = Field(
|
||||
None,
|
||||
description="Optional guidance for texturing. Required in practice for imported models, "
|
||||
"which carry no source image to infer texture from.",
|
||||
)
|
||||
|
||||
|
||||
class TripoRefineModelRequest(BaseModel):
|
||||
@ -307,6 +316,17 @@ class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
|
||||
orientation: str | None = None
|
||||
|
||||
|
||||
class TripoImportModelRequest(BaseModel):
    """Body of the comfy-api composite import endpoint (/proxy/tripo/v2/openapi/import).

    The caller first uploads the model file to ComfyUI API storage. The backend
    then fetches it from `url`, re-uploads it to Tripo's storage and creates the
    import_model task server-side.
    """

    url: str = Field(
        ...,
        description="ComfyUI API storage download URL of the model file",
    )
    format: str = Field(
        ...,
        description='File format: "glb", "fbx", "obj" or "stl"',
    )
|
||||
|
||||
|
||||
class TripoTaskOutput(BaseModel):
|
||||
model: str | None = Field(None, description="URL to the model")
|
||||
base_model: str | None = Field(None, description="URL to the base model")
|
||||
|
||||
@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import (
|
||||
Model4,
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
RunwayAleph2IO,
|
||||
RunwayAleph2KeyframeChain,
|
||||
RunwayAleph2KeyframeItem,
|
||||
RunwayAleph2PromptImageChain,
|
||||
RunwayAleph2PromptImageItem,
|
||||
RunwayAleph2Request,
|
||||
RunwayAleph2Response,
|
||||
RunwayAleph2KeyframeSeconds,
|
||||
RunwayAleph2KeyframeAt,
|
||||
RunwayAleph2PromptImage,
|
||||
RunwayAleph2TimestampPosition,
|
||||
RunwayAleph2RelativePosition,
|
||||
RunwayAleph2ContentModeration,
|
||||
KEYFRAME_MODE_SECONDS,
|
||||
KEYFRAME_MODE_AT,
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP,
|
||||
PROMPT_IMAGE_MODE_POSITION,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_duration,
|
||||
upload_images_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
download_url_to_image_tensor,
|
||||
ApiEndpoint,
|
||||
@ -45,6 +65,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||
|
||||
@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256
|
||||
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||
|
||||
|
||||
class RunwayApiError(Exception):
|
||||
"""Base exception for Runway API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||
|
||||
@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status.value,
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@ -127,7 +133,7 @@ async def generate_video(
|
||||
|
||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return await download_url_to_video_output(video_url)
|
||||
@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
mime_type="image/png",
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
raise ValueError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
@ -514,11 +520,321 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
||||
raise ValueError("Runway task succeeded but no image data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
_TIMING_ABSOLUTE = "Absolute time (seconds)"
|
||||
_TIMING_FRACTION = "Fraction of duration (0.0-1.0)"
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeNode(IO.ComfyNode):
    """Appends one guidance image, anchored to the input video, to a keyframe chain."""

    @classmethod
    def define_schema(cls):
        """Declare the node: an image, a timing selector and an optional upstream chain."""
        absolute_option = IO.DynamicCombo.Option(
            _TIMING_ABSOLUTE,
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from start of the input video where this image applies.",
                ),
            ],
        )
        fraction_option = IO.DynamicCombo.Option(
            _TIMING_FRACTION,
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Where in the input video this image applies, "
                    "as a fraction of its duration (0.0 = start, 1.0 = end).",
                ),
            ],
        )
        node_inputs = [
            IO.Image.Input(
                "image",
                tooltip="The guidance image to apply at the chosen moment of the input video.",
            ),
            IO.DynamicCombo.Input(
                "timing",
                options=[absolute_option, fraction_option],
                tooltip="How to place this image on the input video's timeline.",
            ),
            IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
                "keyframes",
                optional=True,
                tooltip="Optional earlier keyframes to chain with this one.",
            ),
        ]
        return IO.Schema(
            node_id="RunwayAleph2KeyframeNode",
            display_name="Runway Aleph2 Keyframe",
            category="partner/video/Runway",
            description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 "
            "steers the edit at that point of your footage. Connect this to the 'keyframes' input of "
            "the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
            "'keyframes' input below.",
            inputs=node_inputs,
            outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")],
        )

    @classmethod
    def execute(
        cls,
        image: Input.Image,
        timing: dict,
        keyframes: RunwayAleph2KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Return a chain holding the upstream keyframes plus this one.

        The incoming chain is cloned rather than mutated. ``timing`` carries the
        selected option under ``"timing"`` and the matching numeric value under
        ``"seconds"`` or ``"fraction"``.
        """
        if keyframes is None:
            chain = RunwayAleph2KeyframeChain()
        else:
            chain = keyframes.clone()

        use_seconds = timing["timing"] == _TIMING_ABSOLUTE
        mode = KEYFRAME_MODE_SECONDS if use_seconds else KEYFRAME_MODE_AT
        value = float(timing["seconds"] if use_seconds else timing["fraction"])

        chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value))
        return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageNode(IO.ComfyNode):
    """Appends one guidance image, anchored to the output video, to a prompt-image chain."""

    @classmethod
    def define_schema(cls):
        """Declare the node: an image, a position selector and an optional upstream chain."""
        absolute_option = IO.DynamicCombo.Option(
            _TIMING_ABSOLUTE,
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from start of the output video where this image applies.",
                ),
            ],
        )
        fraction_option = IO.DynamicCombo.Option(
            _TIMING_FRACTION,
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Where in the output video this image applies, "
                    "as a fraction of its duration (0.0 = start, 1.0 = end).",
                ),
            ],
        )
        node_inputs = [
            IO.Image.Input(
                "image",
                tooltip="The guidance image to place at the chosen moment of the output video.",
            ),
            IO.DynamicCombo.Input(
                "position",
                options=[absolute_option, fraction_option],
                tooltip="How to place this image on the output video's timeline.",
            ),
            IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
                "prompt_images",
                optional=True,
                tooltip="Optional earlier prompt images to chain with this one.",
            ),
        ]
        return IO.Schema(
            node_id="RunwayAleph2PromptImageNode",
            display_name="Runway Aleph2 Prompt Image",
            category="partner/video/Runway",
            description="Anchor a guidance image to a moment of the output (result) video, to guide what "
            "the edited video looks like at that point. Connect this to the 'prompt_images' input of the "
            "Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
            "'prompt_images' input below.",
            inputs=node_inputs,
            outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")],
        )

    @classmethod
    def execute(
        cls,
        image: Input.Image,
        position: dict,
        prompt_images: RunwayAleph2PromptImageChain | None = None,
    ) -> IO.NodeOutput:
        """Return a chain holding the upstream prompt images plus this one.

        The incoming chain is cloned rather than mutated. ``position`` carries
        the selected option under ``"position"`` and the matching numeric value
        under ``"seconds"`` or ``"fraction"``.
        """
        if prompt_images is None:
            chain = RunwayAleph2PromptImageChain()
        else:
            chain = prompt_images.clone()

        use_seconds = position["position"] == _TIMING_ABSOLUTE
        mode = PROMPT_IMAGE_MODE_TIMESTAMP if use_seconds else PROMPT_IMAGE_MODE_POSITION
        value = float(position["seconds"] if use_seconds else position["fraction"])

        chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value))
        return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2VideoToVideoNode(IO.ComfyNode):
    """API node that edits a video with Runway's Aleph2 model from a text prompt."""

    @classmethod
    def define_schema(cls):
        """Declare prompt, video, seed, moderation and the two optional guidance chains."""
        return IO.Schema(
            node_id="RunwayAleph2VideoToVideoNode",
            display_name="Runway Aleph2 Video to Video",
            category="partner/video/Runway",
            description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms "
            "your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping "
            "the original motion and timing; the output resolution matches the input video, which must be "
            "2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to "
            "the input video) or prompt images (anchored to the output video) - use one or the other, not both.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Describes what should appear in the output (1-1000 characters).",
                ),
                IO.Video.Input(
                    "video",
                    tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=4294967295,
                    step=1,
                    control_after_generate=True,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Random seed for generation",
                ),
                IO.Combo.Input(
                    "public_figure_threshold",
                    options=["auto", "low"],
                    default="low",
                    tooltip="Content moderation for recognizable public figures.",
                ),
                IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
                    "keyframes",
                    optional=True,
                    tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). "
                    "Use keyframes or prompt images, not both.",
                ),
                IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
                    "prompt_images",
                    optional=True,
                    tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). "
                    "Use keyframes or prompt images, not both.",
                ),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        video: Input.Video,
        seed: int,
        public_figure_threshold: str = "low",
        keyframes: RunwayAleph2KeyframeChain | None = None,
        prompt_images: RunwayAleph2PromptImageChain | None = None,
    ) -> IO.NodeOutput:
        """Validate the inputs, upload the media, run the Aleph2 task and return the video.

        All local validation (prompt length, video duration and frame rate,
        guidance item counts, absolute timestamps) happens before anything is
        uploaded, so invalid input fails fast without wasting uploads.

        Raises:
            ValueError: if the frame rate exceeds 30 fps, if both keyframes and
                prompt images are supplied, if more than 5 guidance items of one
                kind are supplied, if an absolute timestamp lies beyond the
                input video, if the API returns no task id, or if the finished
                task carries no output. ``validate_string`` and
                ``validate_video_duration`` may raise as well.
        """
        validate_string(prompt, min_length=1, max_length=1000)
        validate_video_duration(
            video,
            min_duration=2.0,
            max_duration=30.0,
        )

        # Best effort: if the frame rate cannot be read, the check is skipped.
        try:
            fps = float(video.get_frame_rate())
        except Exception:
            fps = None
        # Small tolerance so that e.g. 30.000x fps sources are still accepted.
        if fps is not None and fps > 30.0 + 0.01:
            raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.")

        keyframe_items = list(keyframes.items) if keyframes is not None else []
        prompt_image_items = list(prompt_images.items) if prompt_images is not None else []

        if keyframe_items and prompt_image_items:
            raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.")
        if len(keyframe_items) > 5:
            raise ValueError("Aleph2 supports at most 5 keyframes.")
        if len(prompt_image_items) > 5:
            raise ValueError("Aleph2 supports at most 5 prompt images.")

        # Best effort: without a known duration the timestamp checks are skipped.
        video_duration: float | None = None
        try:
            video_duration = video.get_duration()
        except Exception:
            video_duration = None

        def _check_seconds(value: float, label: str) -> None:
            """Reject an absolute timestamp that lies past the end of the input video."""
            if video_duration is not None and value > video_duration + 0.0001:
                raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).")

        # Check every timestamp up front, before the first upload is started.
        for item in keyframe_items:
            if item.mode == KEYFRAME_MODE_SECONDS:
                _check_seconds(item.value, "Keyframe timestamp")
        for item in prompt_image_items:
            if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
                _check_seconds(item.value, "Prompt image timestamp")

        video_url = await upload_video_to_comfyapi(cls, video)

        keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = []
        for item in keyframe_items:
            image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
            if item.mode == KEYFRAME_MODE_SECONDS:
                keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url))
            else:
                keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url))

        prompt_image_models: list[RunwayAleph2PromptImage] = []
        for item in prompt_image_items:
            image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
            position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
            if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
                position = RunwayAleph2TimestampPosition(timestampSeconds=item.value)
            else:
                position = RunwayAleph2RelativePosition(positionPercentage=item.value)
            prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url))

        initial_response = await sync_op(
            cls,
            endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"),
            response_model=RunwayAleph2Response,
            data=RunwayAleph2Request(
                promptText=prompt,
                videoUri=video_url,
                seed=seed,
                contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold),
                # Empty lists are sent as None so the optional fields stay unset.
                keyframes=keyframe_models or None,
                promptImage=prompt_image_models or None,
            ),
        )
        # The response model declares the id as optional; without it there is
        # nothing to poll, so fail here instead of polling ".../tasks/None".
        if not initial_response.id:
            raise ValueError("Runway did not return a task ID for the Aleph2 request.")

        final_response = await get_response(cls, initial_response.id)
        if not final_response.output:
            raise ValueError("Runway task succeeded but no video data found in response.")

        return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
class RunwayExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension):
|
||||
RunwayImageToVideoNodeGen3a,
|
||||
RunwayImageToVideoNodeGen4,
|
||||
RunwayTextToImageNode,
|
||||
RunwayAleph2VideoToVideoNode,
|
||||
RunwayAleph2KeyframeNode,
|
||||
RunwayAleph2PromptImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.tripo import (
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoAnimateRigRequest,
|
||||
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoFileEmptyReference,
|
||||
TripoFileReference,
|
||||
TripoImageToModelRequest,
|
||||
TripoImportModelRequest,
|
||||
TripoModelVersion,
|
||||
TripoMultiviewToModelRequest,
|
||||
TripoOrientation,
|
||||
@ -21,6 +22,7 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoTaskType,
|
||||
TripoTextToModelRequest,
|
||||
TripoTextureModelRequest,
|
||||
TripoTexturePrompt,
|
||||
TripoUrlReference,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -28,6 +30,7 @@ from comfy_api_nodes.util import (
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
)
|
||||
|
||||
@ -538,6 +541,14 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"texture_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Optional text guidance for texturing. Required in practice for imported "
|
||||
"models (Tripo: Import Model), which carry no source image to infer colors from.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -571,6 +582,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
texture_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
@ -583,6 +595,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
texture_prompt=TripoTexturePrompt(text=texture_prompt.strip()) if texture_prompt.strip() else None,
|
||||
),
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=80)
|
||||
@ -915,6 +928,90 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
return await poll_until_finished(cls, response, average_duration=30)
|
||||
|
||||
|
||||
class TripoImportModelNode(IO.ComfyNode):
    """Imports an external 3D model into Tripo, producing a MODEL_TASK_ID for post-processing nodes."""

    SUPPORTED_FORMATS = ("glb", "fbx", "obj", "stl")

    @classmethod
    def define_schema(cls):
        """Declare a single 3D-model input and the resulting model task id output."""
        model_input = IO.MultiType.Input(
            "model_3d",
            types=[IO.File3DGLB, IO.File3DFBX, IO.File3DOBJ, IO.File3DSTL, IO.File3DAny],
            tooltip="3D model to import (GLB / FBX / OBJ / STL, up to 150 MB). "
            "OBJ and STL files carry no embedded textures.",
        )
        return IO.Schema(
            node_id="TripoImportModelNode",
            display_name="Tripo: Import Model",
            category="partner/3d/Tripo",
            description="Import an external 3D model (e.g. from Rodin, Hunyuan3D or a local file) into Tripo "
            "to use it with Tripo's post-processing nodes: Texture, Rig, Convert. "
            "GLB is recommended: textures survive import only when embedded in the file. "
            "Note that texturing an imported model requires a texture prompt.",
            inputs=[model_input],
            outputs=[
                IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"text","text":"Free"}""",
            ),
        )

    @classmethod
    async def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput:
        """Validate the file, upload it, create the import task and wait for it to finish.

        Raises:
            ValueError: for a ``gltf`` file, any format outside
                ``SUPPORTED_FORMATS``, or a file larger than 150 MB.
            RuntimeError: if the import request reports a non-zero code or the
                polled task does not end in ``SUCCESS``.
        """
        # Normalise ".GLB" / "glb" / None to a bare lower-case extension.
        file_format = (model_3d.format or "").lstrip(".").lower()
        if file_format == "gltf":
            raise ValueError(
                "GLTF (.gltf) references external files and cannot be imported. Export a single-file GLB instead."
            )
        if file_format not in cls.SUPPORTED_FORMATS:
            supported = ", ".join(fmt.upper() for fmt in cls.SUPPORTED_FORMATS)
            raise ValueError(
                f"Unsupported 3D format '{file_format or 'unknown'}'. "
                f"Tripo import supports: {supported}."
            )

        one_mb = 1024 * 1024
        size = len(model_3d.get_bytes())
        if size > 150 * one_mb:
            raise ValueError(f"Model file is {size / one_mb:.1f} MB; Tripo import allows up to 150 MB.")

        url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
        import_request = TripoImportModelRequest(url=url, format=file_format)
        response = await sync_op(
            cls,
            endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/import", method="POST"),
            response_model=TripoTaskResponse,
            data=import_request,
        )
        if response.code != 0:
            raise RuntimeError(f"Failed to import model: {response.error}")

        task_id = response.data.task_id
        terminal_failures = [
            TripoTaskStatus.FAILED,
            TripoTaskStatus.CANCELLED,
            TripoTaskStatus.UNKNOWN,
            TripoTaskStatus.BANNED,
            TripoTaskStatus.EXPIRED,
        ]
        response_poll = await poll_op(
            cls,
            poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"),
            response_model=TripoTaskResponse,
            failed_statuses=terminal_failures,
            status_extractor=lambda x: x.data.status,
            progress_extractor=lambda x: x.data.progress,
            estimated_duration=10,
        )
        if response_poll.data.status == TripoTaskStatus.SUCCESS:
            return IO.NodeOutput(task_id)
        raise RuntimeError(f"Failed to import model: {response_poll}")
|
||||
|
||||
|
||||
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
|
||||
return (
|
||||
"("
|
||||
@ -1292,6 +1389,7 @@ class TripoExtension(ComfyExtension):
|
||||
TripoP1TextToModelNode,
|
||||
TripoP1ImageToModelNode,
|
||||
TripoP1MultiviewToModelNode,
|
||||
TripoImportModelNode,
|
||||
TripoTextureNode,
|
||||
TripoRefineNode,
|
||||
TripoRigNode,
|
||||
|
||||
156
comfy_extras/nodes_cube.py
Normal file
156
comfy_extras/nodes_cube.py
Normal file
@ -0,0 +1,156 @@
|
||||
"""
|
||||
Nodes for native Roblox Cube3D text-to-3D support.
|
||||
|
||||
Graph:
|
||||
CLIPLoader(clip-l) -> CLIPTextEncode -> CONDITIONING
|
||||
UNETLoader(shape_gpt) -> MODEL --\
|
||||
VAELoader(shape_tokenizer) -> VAE -> CubeCodebookPatch -> MODEL
|
||||
CFGGuider(MODEL, pos, neg, cfg) + SamplerCube + (trivial sigmas) + EmptyCubeLatent
|
||||
-> SamplerCustomAdvanced -> LATENT (token IDs)
|
||||
VAEDecodeCube(VAE, LATENT) -> MESH -> SaveGLB
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.ldm.cube.vae
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy_extras.nodes_save_3d import pack_variable_mesh_batch
|
||||
|
||||
|
||||
class EmptyCubeLatent(IO.ComfyNode):
    """Creates an all-zero placeholder latent for Cube shape-token sampling."""

    @classmethod
    def define_schema(cls):
        """Declare the token-count and batch-size inputs and the latent output."""
        token_count = IO.Int.Input(
            "num_tokens",
            default=1024,
            min=1,
            max=8192,
            tooltip="Shape token sequence length. Must match the tokenizer "
                    "(1024 for cube3d-v0.5, 512 for v0.1).",
        )
        batch = IO.Int.Input("batch_size", default=1, min=1, max=64)
        return IO.Schema(
            node_id="EmptyCubeLatent",
            category="latent/3d",
            inputs=[token_count, batch],
            outputs=[IO.Latent.Output()],
        )

    @classmethod
    def execute(cls, num_tokens, batch_size) -> IO.NodeOutput:
        """Return a zero latent of shape (batch_size, 1, num_tokens) tagged as cube tokens."""
        # Channels-first 1D layout (B, 1, num_tokens), following the (B, C, L)
        # convention used for Hunyuan3Dv2 with latent_channels=1. Only the
        # sequence length matters to the sampler.
        device = comfy.model_management.intermediate_device()
        shape = (batch_size, 1, num_tokens)
        zeros = torch.zeros(shape, device=device)
        return IO.NodeOutput({"samples": zeros, "type": "cube_tokens"})
|
||||
|
||||
|
||||
class CubeCodebookPatch(IO.ComfyNode):
    """Write the projected VQ codebook into the GPT token-embedding table.

    Upstream copies shape_proj(tokenizer.codebook) into wte.weight[:num_codes]
    when the model is loaded; generation is garbage without that step. Here it
    is applied as a ModelPatcher object patch so it composes with the normal
    model loading/offload machinery."""

    @classmethod
    def define_schema(cls):
        """Declare the model and VAE inputs and the patched model output."""
        return IO.Schema(
            node_id="CubeCodebookPatch",
            display_name="Cube Codebook Patch",
            category="advanced/model",
            inputs=[IO.Model.Input("model"), IO.Vae.Input("vae")],
            outputs=[IO.Model.Output()],
        )

    @classmethod
    def execute(cls, model, vae) -> IO.NodeOutput:
        """Return a clone of ``model`` whose embedding table starts with the projected codebook."""
        embedding_key = "diffusion_model.transformer.wte.weight"
        gpt = model.get_model_object("diffusion_model")

        # Codebook is (num_codes, embed_dim) in fp32; move it to the projection
        # layer's device/dtype before projecting to (num_codes, n_embd).
        codebook = vae.first_stage_model.bottleneck.block.get_codebook()
        proj_weight = gpt.shape_proj.weight
        projected = gpt.shape_proj(codebook.to(device=proj_weight.device, dtype=proj_weight.dtype))

        # Work on a copy so the original embedding table is left untouched.
        patched_weight = model.get_model_object(embedding_key).clone()
        num_codes = projected.shape[0]
        patched_weight[:num_codes] = projected.to(device=patched_weight.device, dtype=patched_weight.dtype)

        patched_model = model.clone()
        patched_model.add_object_patch(
            embedding_key,
            torch.nn.Parameter(patched_weight, requires_grad=False),
        )
        return IO.NodeOutput(patched_model)
|
||||
|
||||
|
||||
class SamplerCube(IO.ComfyNode):
    """Exposes the "cube" sampler with a configurable ``top_p``."""

    @classmethod
    def define_schema(cls):
        """Declare the ``top_p`` input and the sampler output."""
        top_p_input = IO.Float.Input(
            "top_p",
            default=1.0,
            min=0.0,
            max=1.0,
            step=0.01,
            tooltip="1.0 = deterministic greedy (upstream default). "
                    "<1.0 enables nucleus sampling.",
        )
        return IO.Schema(
            node_id="SamplerCube",
            display_name="Sampler Cube (autoregressive)",
            category="sampling/custom_sampling/samplers",
            inputs=[top_p_input],
            outputs=[IO.Sampler.Output()],
        )

    @classmethod
    def execute(cls, top_p) -> IO.NodeOutput:
        """Build the "cube" ksampler, forwarding ``top_p`` as its only extra option."""
        sampler_options = {"top_p": top_p}
        sampler = comfy.samplers.ksampler("cube", sampler_options)
        return IO.NodeOutput(sampler)
|
||||
|
||||
|
||||
class VAEDecodeCube(IO.ComfyNode):
    """Decodes Cube token latents into a mesh batch."""

    @classmethod
    def define_schema(cls):
        """Declare the VAE, latent, resolution and chunk-size inputs and the mesh output."""
        resolution = IO.Float.Input(
            "resolution_base",
            default=8.0,
            min=4.0,
            max=10.0,
            step=0.5,
            tooltip="Grid cells per axis = 2^resolution_base. 8.0 matches "
                    "upstream default (257^3 grid).",
        )
        chunking = IO.Int.Input("chunk_size", default=100000, min=1000, max=2000000, advanced=True)
        return IO.Schema(
            node_id="VAEDecodeCube",
            display_name="VAE Decode Cube (3D)",
            category="latent/3d",
            inputs=[
                IO.Vae.Input("vae"),
                IO.Latent.Input("samples"),
                resolution,
                chunking,
            ],
            outputs=[IO.Mesh.Output()],
        )

    @classmethod
    def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
        """Decode the latent to occupancy logits and turn each batch entry into a mesh."""
        # Managed decode: comfy.sd.VAE.decode takes care of model loading and
        # device/dtype handling and hands back the occupancy grid logits
        # (B, gx, gy, gz). Mesh extraction (marching cubes) is done in this node.
        decode_options = {"resolution_base": resolution_base, "chunk_size": chunk_size}
        grid = vae.decode(samples["samples"], vae_options=decode_options)

        # decode_bounds is read as [min_x, min_y, min_z, max_x, max_y, max_z].
        bounds = vae.first_stage_model.decode_bounds
        bbox_min = np.array(bounds[0:3])
        bbox_size = np.array(bounds[3:6]) - bbox_min
        grid_size = list(grid.shape[1:])

        verts_list = []
        faces_list = []
        for logits in grid:
            verts, faces = comfy.ldm.cube.vae.grid_logits_to_mesh(logits, grid_size, bbox_size, bbox_min)
            verts_list.append(torch.from_numpy(verts))
            faces_list.append(torch.from_numpy(faces.astype(np.int64)))

        return IO.NodeOutput(pack_variable_mesh_batch(verts_list, faces_list))
|
||||
|
||||
|
||||
class CubeExtension(ComfyExtension):
    """Extension that registers the Cube3D nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        nodes: list[type[IO.ComfyNode]] = [
            EmptyCubeLatent,
            CubeCodebookPatch,
            SamplerCube,
            VAEDecodeCube,
        ]
        return nodes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> CubeExtension:
    """Create and return the extension instance for this module."""
    extension = CubeExtension()
    return extension
|
||||
@ -134,6 +134,17 @@ class CreateVideo(io.ComfyNode):
|
||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||
io.Int.Input(
|
||||
"bit_depth",
|
||||
min=8,
|
||||
max=10,
|
||||
default=8,
|
||||
step=2,
|
||||
tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less"
|
||||
" banding, but some players and downstream nodes may not support it.",
|
||||
optional=True,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
@ -141,9 +152,14 @@ class CreateVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
||||
def execute(
|
||||
cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8,
|
||||
) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)),
|
||||
bit_depth=bit_depth,
|
||||
)
|
||||
)
|
||||
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@ -154,7 +170,7 @@ class GetVideoComponents(io.ComfyNode):
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
category="video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
description="Extracts all components from a video: frames, audio, framerate, and bit depth.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||
],
|
||||
@ -162,13 +178,14 @@ class GetVideoComponents(io.ComfyNode):
|
||||
io.Image.Output(display_name="images"),
|
||||
io.Audio.Output(display_name="audio"),
|
||||
io.Float.Output(display_name="fps"),
|
||||
io.Int.Output(display_name="bit_depth"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth())
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
|
||||
@ -1 +1 @@
|
||||
comfyui_manager==4.2.1
|
||||
comfyui_manager==4.2.2
|
||||
|
||||
1
nodes.py
1
nodes.py
@ -2433,6 +2433,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
"nodes_ar_video.py",
|
||||
"nodes_cube.py",
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
"nodes_glsl.py",
|
||||
|
||||
@ -27,6 +27,7 @@ import logging
|
||||
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_api import feature_flags
|
||||
@ -690,6 +691,7 @@ class PromptServer():
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
"deploy_environment": get_deploy_environment(),
|
||||
"argv": sys.argv
|
||||
},
|
||||
"devices": device_entries
|
||||
|
||||
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
import torch
|
||||
import av
|
||||
import numpy as np
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util.video_types import VideoComponents
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def gradient_components():
    """Video components containing a shallow ramp from 0.25 to 0.30.

    The ramp varies along the third axis only (64 steps); every frame, row
    and channel repeats the same values. The value range is deliberately
    narrow so that coarse quantisation collapses neighbouring steps, which
    the banding assertions in this module rely on.
    """
    num_frames, rows, cols = 3, 64, 64
    column_values = torch.linspace(0.25, 0.30, cols)
    # Reshape to (1, 1, cols, 1) so the ramp broadcasts over frames, rows and
    # the 3 channels; contiguous() materialises the expanded view.
    frames = column_values.view(1, 1, cols, 1).expand(num_frames, rows, cols, 3).contiguous()
    return VideoComponents(images=frames, frame_rate=Fraction(30))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def src8(gradient_components, tmp_path_factory):
    """Save the gradient with default settings and return the mp4 path.

    No ``bit_depth`` is passed, so this exercises the default depth; the
    tests in this module expect the result to be 8-bit h264.
    """
    out_dir = tmp_path_factory.mktemp("video")
    target = str(out_dir / "src8.mp4")
    video = VideoFromComponents(gradient_components)
    video.save_to(target)
    return target
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def src10(gradient_components, tmp_path_factory):
    """Save the gradient with ``bit_depth=10`` and return the mp4 path.

    The tests in this module expect the result to be 10-bit h264.
    """
    out_dir = tmp_path_factory.mktemp("video")
    target = str(out_dir / "src10.mp4")
    video = VideoFromComponents(gradient_components, bit_depth=10)
    video.save_to(target)
    return target
|
||||
|
||||
|
||||
def probe(path):
    """Describe the first video stream of the file at ``path``.

    Returns a ``(codec name, pixel format name, bit depth)`` tuple, where the
    bit depth is the largest ``bits`` value among the pixel format's
    components.
    """
    with av.open(path) as container:
        video_stream = container.streams.video[0]
        codec_name = video_stream.codec.name
        pix_fmt_name = video_stream.format.name
        depth = max(component.bits for component in video_stream.format.components)
        return (codec_name, pix_fmt_name, depth)
|
||||
|
||||
|
||||
def decoded_levels(path):
    """Count distinct values in the first decoded frame, as a banding measure.

    The frame is converted to ``gbrpf32le`` and only index 0 of the last
    array axis is inspected; fewer distinct values means coarser tonal steps.
    """
    with av.open(path) as container:
        video_stream = container.streams.video[0]
        first_frame = next(container.decode(video_stream))
        pixels = first_frame.to_ndarray(format="gbrpf32le")
        return len(np.unique(pixels[..., 0]))
|
||||
|
||||
|
||||
def video_packet_bytes(path):
    """Return the raw payload of every non-empty packet of the first video stream.

    Two files yield equal lists only when the encoded video data was copied
    unchanged, so this is used to tell a remux from a re-encode.
    """
    with av.open(path) as container:
        payloads = []
        for packet in container.demux(container.streams.video[0]):
            # Zero-size packets carry no data and are skipped.
            if packet.size:
                payloads.append(bytes(packet))
        return payloads
|
||||
|
||||
|
||||
def test_create_video_bit_depth(src8, src10):
    """The requested bit depth decides the encoded format; 10-bit bands less.

    The default output must be 8-bit ``yuv420p`` and the ``bit_depth=10``
    output must be ``yuv420p10le``, both h264. The 10-bit file must also
    decode to more than twice as many distinct levels as the 8-bit one.
    """
    info8 = probe(src8)
    assert info8 == ("h264", "yuv420p", 8)

    info10 = probe(src10)
    assert info10 == ("h264", "yuv420p10le", 10)

    levels10 = decoded_levels(src10)
    levels8 = decoded_levels(src8)
    assert levels10 > 2 * levels8
|
||||
|
||||
|
||||
def test_save_auto_keeps_source_depth(src8, src10, tmp_path):
    """Saving a file-backed video without ``bit_depth`` must not touch the stream.

    For both the 8-bit and the 10-bit source, the saved copy must probe
    identically to the source and carry byte-identical video packets.
    """
    sources = {"p8": src8, "p10": src10}
    for name, source in sources.items():
        copy_path = str(tmp_path / f"{name}.mp4")
        VideoFromFile(source).save_to(copy_path)

        assert probe(copy_path) == probe(source)
        assert video_packet_bytes(copy_path) == video_packet_bytes(source)
|
||||
|
||||
|
||||
def test_save_explicit_depth_reencodes(src8, src10, tmp_path):
    """Passing a ``bit_depth`` that differs from the source yields that depth.

    Covers both directions: 10-bit source saved as 8-bit, and 8-bit source
    saved as 10-bit.
    """

    def reencode(source, filename, depth):
        # Save ``source`` at the requested depth and report what was written.
        target = str(tmp_path / filename)
        VideoFromFile(source).save_to(target, bit_depth=depth)
        return probe(target)

    assert reencode(src10, "down8.mp4", 8) == ("h264", "yuv420p", 8)
    assert reencode(src8, "up10.mp4", 10) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_trim_keeps_source_depth(src10, tmp_path):
    """A trimmed 10-bit video must still be saved as 10-bit h264.

    The trim starts at 0 and lasts 1/30 s (one frame of the 30 fps source
    fixture), with ``strict_duration`` disabled.
    """
    frame_duration = 1 / 30
    trimmed = VideoFromFile(src10).as_trimmed(
        start_time=0, duration=frame_duration, strict_duration=False
    )
    out_path = str(tmp_path / "trim.mp4")
    trimmed.save_to(out_path)

    assert probe(out_path) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_get_bit_depth(gradient_components, src8, src10):
    """``get_bit_depth`` reports the depth of file- and component-backed videos.

    File-backed videos report the depth of the file they wrap; component-backed
    videos report the requested depth, or 8 when none is given.
    """
    for path, expected in ((src8, 8), (src10, 10)):
        assert VideoFromFile(path).get_bit_depth() == expected

    explicit = VideoFromComponents(gradient_components, bit_depth=10)
    assert explicit.get_bit_depth() == 10

    default = VideoFromComponents(gradient_components)
    assert default.get_bit_depth() == 8
|
||||
Reference in New Issue
Block a user