Compare commits

..

17 Commits

Author SHA1 Message Date
1339cb570d Coalesce empty query params to None in /node_startup_errors route
?source= or ?module_name= or ?pack_id= (param present but blank) would have returned {} because the helper treated the empty string as an exact-match filter. Coalesce to None at the route boundary so a present-but-blank query param behaves the same as the param being absent. The helper's own behaviour is unchanged and locked in by a new assertion.

Amp-Thread-ID: https://ampcode.com/threads/T-019e86fd-b68f-74de-8c91-d2662377424a
Co-authored-by: Amp <amp@ampcode.com>
2026-06-01 23:39:32 -07:00
4eef53041e Match PyProjectConfig shape for pyproject; add pack_id/module_name/source query filters
Two reviewer-requested improvements to GET /node_startup_errors:

1. Emit the pyproject metadata in the same {project: {...}, tool_comfy: {...}}
   shape that comfy_config.config_parser.extract_node_configuration already
   returns, instead of inventing a flat {pack_id, display_name, ...} bag.
   API consumers can now parse the pyproject block straight through the
   shared PyProjectConfig pydantic model. Empty / default-valued leaves
   are pruned by a small recursive _prune_empty helper so the payload
   stays compact, but nesting and field names match the source-of-truth.

2. Add optional source, module_name, and pack_id query parameters
   (combined with AND) so a frontend / Manager can ask ?pack_id=foo
   instead of grep'ing through the whole grouped response. pack_id
   resolves against pyproject.project.name; entries without a parsed
   pyproject are naturally excluded from a pack_id query.

The grouping + filtering + module_path stripping moves into

odes.filter_node_startup_errors so the route handler is a one-liner and
the helper is unit-testable without spinning up an aiohttp app.

Tests: 5 new unit tests covering each filter branch, AND-combination, and
empty-result behaviour, plus an updated pyproject-metadata assertion that
checks the nested PyProjectConfig shape, plus a focused test for the
_prune_empty helper.
2026-06-01 23:33:37 -07:00
7259e664ef Defer record_node_startup_error in prestartup error path; add docstrings
Buffer prestartup failures into a module-level list inside main.py
instead of importing 'nodes' (and therefore 'torch') from within the
exception handler. After the normal 'import nodes' line, drain the
buffer via nodes.record_node_startup_error so bootstrap order stays
deterministic regardless of whether a prestartup script succeeded.

Also convert the explanatory '#' comment on the new
/node_startup_errors endpoint into a proper docstring and add a
docstring to execute_prestartup_script, addressing CodeRabbit's
docstring-coverage warning on this PR.

Addresses review feedback on PR #13184.

Amp-Thread-ID: https://ampcode.com/threads/T-019e2f90-26fe-7048-9855-5ff39d08a3e0
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 14:09:01 -07:00
ae539cfa0a Merge branch 'master' into feature/custom-node-startup-errors 2026-05-21 12:58:06 -07:00
b293f8cefd [Partner Nodes] add widget for automatic upscaling for the ByteDance2Reference node (#14032)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-21 11:58:03 -07:00
2ca1480f91 chore: update workflow templates to v0.9.82 (#14034) 2026-05-21 11:48:20 -07:00
6ecf5eca7a [Partner Nodes] add OpenRouter LLM node (#14007)
* [Partner Nodes] add reasoning widget to Anthropic node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] add new OpenRouterLLM node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix passing images to Grok LLM

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-21 11:36:11 -07:00
03e511862e Fix reshaping lora application (#14031)
* ModelPatcherDyanmic: purge stale vbar allocs on force cast

* ModelPatcherDynamic: restore backups before load

If doing a clean reload, mutative changes (lora application) could be
applied on-top of the already loaded weight. Restore from backup
unconditionally so that the new load is clean.
2026-05-21 09:47:16 -07:00
aab41a9ddb fix(lanczos): correct dimension transposition for single-channel tensors (#12679) 2026-05-21 23:47:20 +08:00
8f82b16993 Merge branch 'master' into feature/custom-node-startup-errors 2026-05-15 16:31:50 -07:00
72fe66a18b Hoist 'import traceback' to top of main.py
Minor cleanup from code review: traceback is stdlib so there's no circular-import concern keeping it inline. The 'from nodes import record_node_startup_error' stays inline because nodes.py imports from contexts that would create a cycle at module load time.

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019e23a1-2acc-7619-bd0e-f783d1368ef3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-15 00:48:23 -07:00
07ff14ae02 Use module_parent string directly as 'source'; drop fixed-enum mapping
The public 'source' field on each NODE_STARTUP_ERRORS entry is now the same string as the internal module_parent passed to load_custom_node ('custom_nodes', 'comfy_extras', 'comfy_api_nodes'), rather than being translated to a separate fixed enum. Treating it as a free-form string keeps the contract durable in case the node-source layout evolves (e.g. comfy_api_nodes eventually moving out of core).

The API endpoint now also dynamically groups by whatever sources are present rather than hardcoding the three known top-level keys; consumers should not assume any particular set of keys is always present.

Drops the _NODE_SOURCE_BY_PARENT map, _node_source_from_parent helper, and the related test. Adds a test covering an arbitrary unknown module_parent value passing through unchanged.

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019e23a1-2acc-7619-bd0e-f783d1368ef3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-14 20:49:35 -07:00
ba1c039a04 Rename /custom_node_startup_errors -> /node_startup_errors
The endpoint covers comfy_extras and comfy_api_nodes failures too, not just user-installed custom nodes, so the path should not pretend otherwise.

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019e23a1-2acc-7619-bd0e-f783d1368ef3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-13 21:05:15 -07:00
6220400ad5 Strip absolute module_path from /custom_node_startup_errors response
The absolute on-disk path is internal detail the frontend/Manager has no use for. Keep it in the in-memory NODE_STARTUP_ERRORS dict for server-side debugging, but exclude it from the public API payload. The user-facing identifier remains module_name (and pyproject.pack_id when available).

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019e23a1-2acc-7619-bd0e-f783d1368ef3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-13 18:10:50 -07:00
af55a2308f Attach pyproject.toml node-pack identity to startup error entries
When a failing module has a pyproject.toml, parse it via comfy_config.config_parser and attach a 'pyproject' field with the Comfy Registry-style identity (pack_id, display_name, publisher_id, version, repository). This gives the frontend/Manager a stable, user-recognizable handle for the failed pack beyond the on-disk folder name.

The lookup is best-effort and never raises: missing toml, missing pydantic-settings dependency, or any parse error simply omits the 'pyproject' key.

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019e23a1-2acc-7619-bd0e-f783d1368ef3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-13 16:31:44 -07:00
3a649984f2 Categorize startup errors by source (custom_node / comfy_extra / api_node)
Expand custom-node startup error tracking to differentiate between user-installed custom_nodes, built-in comfy_extras, and partner comfy_api_nodes. Each NODE_STARTUP_ERRORS entry now carries a 'source' field and is keyed by '<source>:<module_name>' so colliding module names across the three locations don't overwrite each other. The /custom_node_startup_errors endpoint returns errors grouped by source so the frontend/Manager can render distinct sections.

Also captures previously-missed failures from comfy_entrypoint() (phase='entrypoint').

Introduces nodes.record_node_startup_error() helper used by load_custom_node and main.execute_prestartup_script.

Adds tests-unit/node_startup_errors_test.py (6 tests) covering field shape, source mapping for each module_parent, cross-source collisions, and default fallback.

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019e23a1-2acc-7619-bd0e-f783d1368ef3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-13 16:29:17 -07:00
a145651cc0 Track custom node startup errors and expose via API endpoint
Store import and prestartup errors in NODE_STARTUP_ERRORS dict (nodes.py,
main.py) and add GET /custom_node_startup_errors endpoint (server.py) so
the frontend/Manager can distinguish failed imports from missing nodes.

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019d2346-6e6f-75e0-a97f-cdb6e26859f7
Co-authored-by: Amp <amp@ampcode.com>
2026-03-24 23:41:01 -07:00
30 changed files with 1145 additions and 5271 deletions

View File

@ -9,7 +9,6 @@ import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
import comfy.image_encoders.dino3
class Output:
def __getitem__(self, key):
@ -24,7 +23,6 @@ IMAGE_ENCODERS = {
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel
}
class ClipVisionModel():
@ -136,8 +134,6 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
elif 'layer.9.attention.o_proj.bias' in sd: # dinov3
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json")
else:
return None

View File

@ -1,285 +0,0 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
class DINOv3ViTMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
num_tokens = q.shape[-2]
num_patches = sin.shape[-2]
num_prefix_tokens = num_tokens - num_patches
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
return q, k
class DINOv3ViTAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
super().__init__()
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attn = optimized_attention_for_device(query_states.device, mask=False)
attn_output = attn(
query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class DINOv3ViTGatedMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.GELU()
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def get_patches_center_coordinates(
num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
coords_h = coords_h / num_patches_h
coords_w = coords_w / num_patches_w
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
coords = coords.flatten(0, 1)
coords = 2.0 * coords - 1.0
return coords
class DINOv3ViTRopePositionEmbedding(nn.Module):
inv_freq: torch.Tensor
def __init__(self, rope_theta, hidden_size, num_attention_heads, image_size, patch_size, device, dtype):
super().__init__()
self.base = rope_theta
self.head_dim = hidden_size // num_attention_heads
self.num_patches_h = image_size // patch_size
self.num_patches_w = image_size // patch_size
self.patch_size = patch_size
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
_, _, height, width = pixel_values.shape
num_patches_h = height // self.patch_size
num_patches_w = width // self.patch_size
device = pixel_values.device
device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
with torch.amp.autocast(device_type = device_type, enabled=False):
patch_coords = get_patches_center_coordinates(
num_patches_h, num_patches_w, dtype=torch.float32, device=device
)
self.inv_freq = self.inv_freq.to(device)
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
angles = angles.flatten(1, 2)
angles = angles.tile(2)
cos = torch.cos(angles)
sin = torch.sin(angles)
dtype = pixel_values.dtype
return cos.to(dtype=dtype), sin.to(dtype=dtype)
class DINOv3ViTEmbeddings(nn.Module):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype))
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
self.patch_embeddings = operations.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
)
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None):
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embeddings.weight.dtype
patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
if bool_masked_pos is not None:
mask_token = self.mask_token.to(patch_embeddings.dtype)
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
cls_token = self.cls_token.expand(batch_size, -1, -1)
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
device = patch_embeddings.device
cls_token = cls_token.to(device)
register_tokens = register_tokens.to(device)
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
return embeddings
class DINOv3ViTLayer(nn.Module):
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads,
device, dtype, operations):
super().__init__()
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
if use_gated_mlp:
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations)
else:
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attention(
hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
)
hidden_states = self.layer_scale1(hidden_states)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.layer_scale2(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True)
if dtype == torch.float16 and use_bf16:
dtype = torch.bfloat16
elif dtype == torch.float16 and not use_bf16:
dtype = torch.float32
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
num_attention_heads = config["num_attention_heads"]
num_register_tokens = config["num_register_tokens"]
intermediate_size = config["intermediate_size"]
layer_norm_eps = config["layer_norm_eps"]
num_channels = config["num_channels"]
patch_size = config["patch_size"]
rope_theta = config["rope_theta"]
self.embeddings = DINOv3ViTEmbeddings(
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, dtype=dtype, device=device, operations=operations
)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device
)
self.layer = nn.ModuleList(
[DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True,
intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
dtype=dtype, device=device, operations=operations)
for _ in range(num_hidden_layers)])
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: torch.Tensor | None = None,
**kwargs,
):
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(
hidden_states,
position_embeddings=position_embeddings,
)
if kwargs.get("skip_norm_elementwise", False):
sequence_output= F.layer_norm(hidden_states, hidden_states.shape[-1:])
else:
norm = self.norm.to(hidden_states.device)
sequence_output = norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return sequence_output, None, pooled_output, None

View File

@ -1,23 +0,0 @@
{
"model_type": "dinov3",
"hidden_size": 1024,
"image_size": 224,
"initializer_range": 0.02,
"intermediate_size": 4096,
"key_bias": false,
"layer_norm_eps": 1e-05,
"mlp_bias": true,
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"num_register_tokens": 4,
"patch_size": 16,
"pos_embed_rescale": 2.0,
"proj_bias": true,
"query_bias": true,
"rope_theta": 100.0,
"use_gated_mlp": false,
"value_bias": true,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@ -760,8 +760,6 @@ class Hunyuan3Dv2_1(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class Trellis2(LatentFormat): # TODO
latent_channels = 32
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -1,282 +0,0 @@
import torch
import math
from comfy.ldm.modules.attention import optimized_attention
from typing import Tuple, Union, List
from comfy.ldm.trellis2.vae import VarLenTensor
import comfy.ops
# replica of the seedvr2 code
def var_attn_arg(kwargs):
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
max_seqlen_q = kwargs.get("max_seqlen_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = True
if var_length:
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
if not skip_reshape:
# assumes 2D q, k,v [total_tokens, embed_dim]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
b = q.size(0)
dim_head = q.shape[-1]
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
mask = None
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if mask is not None:
if mask.ndim == 2:
mask = mask.unsqueeze(0)
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if var_length:
return out.transpose(1, 2).values()
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
def scaled_dot_product_attention(*args, **kwargs):
num_all_args = len(args) + len(kwargs)
q = None
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs.get('qkv')
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs.get('q')
kv = args[1] if len(args) > 1 else kwargs.get('kv')
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs.get('q')
k = args[1] if len(args) > 1 else kwargs.get('k')
v = args[2] if len(args) > 2 else kwargs.get('v')
if q is not None:
heads = q.shape[2]
else:
heads = qkv.shape[3]
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs)
out = out.permute(0, 2, 1, 3)
return out
def sparse_windowed_scaled_dot_product_self_attention(
qkv,
window_size: int,
shift_window: Tuple[int, int, int] = (0, 0, 0)
):
serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
if serialization_spatial_cache is None:
fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
else:
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
heads = qkv_feats.shape[2]
if optimized_attention.__name__ == 'attention_xformers':
q, k, v = qkv_feats.unbind(dim=1)
q = q.unsqueeze(0) # [1, M, H, C]
k = k.unsqueeze(0) # [1, M, H, C]
v = v.unsqueeze(0) # [1, M, H, C]
#out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_flash':
if 'flash_attn' not in globals():
import flash_attn
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
else:
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
out = out[bwd_indices] # [T, H, C]
return qkv.replace(out)
def calc_window_partition(
tensor,
window_size: Union[int, Tuple[int, ...]],
shift_window: Union[int, Tuple[int, ...]] = 0,
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
DIM = tensor.coords.shape[1] - 1
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
shifted_coords = tensor.coords.clone().detach()
shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
fwd_indices = torch.argsort(shifted_indices)
bwd_indices = torch.empty_like(fwd_indices)
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
seq_lens = torch.bincount(shifted_indices)
mask = seq_lens != 0
seq_lens = seq_lens[mask]
if optimized_attention.__name__ == 'attention_xformers':
if 'xops' not in globals():
import xformers.ops as xops
attn_func_args = {
'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
}
elif optimized_attention.__name__ == 'attention_flash':
attn_func_args = {
'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
'max_seqlen': torch.max(seq_lens)
}
return fwd_indices, bwd_indices, seq_lens, attn_func_args
def sparse_scaled_dot_product_attention(*args, **kwargs):
q=None
arg_names_dict = {
1: ['qkv'],
2: ['q', 'kv'],
3: ['q', 'k', 'v']
}
num_all_args = len(args) + len(kwargs)
for key in arg_names_dict[num_all_args][len(args):]:
assert key in kwargs, f"Missing argument {key}"
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs['qkv']
device = qkv.device
s = qkv
q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
kv_seqlen = q_seqlen
qkv = qkv.feats # [T, 3, H, C]
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs['q']
kv = args[1] if len(args) > 1 else kwargs['kv']
device = q.device
if isinstance(q, VarLenTensor):
s = q
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
q = q.feats # [T_Q, H, C]
else:
s = None
N, L, H, C = q.shape
q_seqlen = [L] * N
q = q.reshape(N * L, H, C) # [T_Q, H, C]
if isinstance(kv, VarLenTensor):
kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
kv = kv.feats # [T_KV, 2, H, C]
else:
N, L, _, H, C = kv.shape
kv_seqlen = [L] * N
kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs['q']
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
device = q.device
if isinstance(q, VarLenTensor):
s = q
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
q = q.feats # [T_Q, H, Ci]
else:
s = None
N, L, H, CI = q.shape
q_seqlen = [L] * N
q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
if isinstance(k, VarLenTensor):
kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
k = k.feats # [T_KV, H, Ci]
v = v.feats # [T_KV, H, Co]
else:
N, L, H, CI, CO = *k.shape, v.shape[-1]
kv_seqlen = [L] * N
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
# TODO: change
if q is not None:
heads = q
else:
heads = qkv
heads = heads.shape[2]
if optimized_attention.__name__ == 'attention_xformers':
if 'xops' not in globals():
import xformers.ops as xops
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1)
elif num_all_args == 2:
k, v = kv.unbind(dim=1)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
out = xops.memory_efficient_attention(q, k, v, mask)[0]
elif optimized_attention.__name__ == 'attention_flash':
if 'flash_attn' not in globals():
import flash_attn
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args in [2, 3]:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
if num_all_args == 1:
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
elif num_all_args == 2:
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif num_all_args == 3:
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif optimized_attention.__name__ == "attention_pytorch":
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args in [2, 3]:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
else:
cu_seqlens_kv = cu_seqlens_q
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1)
elif num_all_args == 2:
k, v = kv.unbind(dim=1)
out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
skip_reshape=True, skip_output_reshape=True)
if s is not None:
return s.replace(out)
else:
return out.reshape(N, L, H, -1)

View File

@ -1,298 +0,0 @@
# will contain every cuda -> pytorch operation
from typing import Optional, Tuple
import torch
UINT32_SENTINEL = 0xFFFFFFFF
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
"""Kernel spatial offsets in the same order as the CUDA/Triton kernels."""
offsets = []
for vx in range(Kw):
for vy in range(Kh):
for vz in range(Kd):
offsets.append((vx * Dw, vy * Dh, vz * Dd))
return torch.tensor(offsets, device=device, dtype=torch.int32)
class TorchHashMap:
"""Sorted-array hashmap backed by torch.searchsorted."""
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
device = keys.device
self.sorted_keys, order = torch.sort(keys.to(torch.long))
self.sorted_vals = values.to(torch.long)[order]
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
self._n = self.sorted_keys.numel()
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
flat = flat_keys.to(torch.long)
if self._n == 0:
return torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
idx = torch.searchsorted(self.sorted_keys, flat)
idx_safe = torch.clamp(idx, max=self._n - 1)
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
out = torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
if found.any():
out[found] = self.sorted_vals[idx_safe[found]].to(torch.int32)
return out
def build_submanifold_neighbor_map(
hashmap,
coords: torch.Tensor,
W, H, D,
Kw, Kh, Kd,
Dw, Dh, Dd,
):
device = coords.device
M = coords.shape[0]
V = Kw * Kh * Kd
half_V = V // 2 + 1
INVALID = -1
# int32 neighbour map: 4 bytes/elem vs 8 bytes for int64
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32)
b = coords[:, 0].long()
x = coords[:, 1].long()
y = coords[:, 2].long()
z = coords[:, 3].long()
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
ox = x - (Kw // 2) * Dw
oy = y - (Kh // 2) * Dh
oz = z - (Kd // 2) * Dd
for v in range(half_V):
if v == half_V - 1:
# Center voxel always maps to itself
neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32)
continue
dx, dy, dz = offsets[v]
kx = ox + dx
ky = oy + dy
kz = oz + dz
valid = (
(kx >= 0) & (kx < W) &
(ky >= 0) & (ky < H) &
(kz >= 0) & (kz < D)
)
flat = (
b[valid] * (W * H * D) +
kx[valid] * (H * D) +
ky[valid] * D +
kz[valid]
)
if flat.numel() > 0:
found = hashmap.lookup_flat(flat)
idx_in_M = torch.where(valid)[0]
neighbor[idx_in_M, v] = found.to(torch.int32)
# BUG FIX: old code used found != hashmap.default_value which
# compared int32 -1 against int64 4294967295 → always True.
# We now explicitly check for valid indices.
valid_found_mask = found >= 0
if valid_found_mask.any():
src_points = idx_in_M[valid_found_mask]
dst_points = found[valid_found_mask].long()
neighbor[dst_points, V - 1 - v] = src_points.to(torch.int32)
return neighbor
def get_recommended_chunk_mem(
device=None,
safety_fraction: float = 0.4,
min_gb: float = 0.25,
max_gb: float = 8.0,
):
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
device = torch.device(device)
if device.type == 'cuda':
try:
idx = device.index if device.index is not None else 0
free_bytes, total_bytes = torch.cuda.mem_get_info(idx)
free_gb = free_bytes / (1024 ** 3)
total_gb = total_bytes / (1024 ** 3)
recommended = free_gb * safety_fraction
result = max(min_gb, min(recommended, max_gb))
return result
except Exception:
try:
idx = device.index if device.index is not None else 0
total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024 ** 3)
except Exception:
total_gb = 16.0
if total_gb < 12:
result = 0.5
elif total_gb < 16:
result = 0.75
elif total_gb < 24:
result = 1.0
elif total_gb < 32:
result = 2.0
elif total_gb < 48:
result = 4.0
else:
result = 6.0
return result
else:
try:
import psutil
avail_gb = psutil.virtual_memory().available / (1024 ** 3)
recommended = avail_gb * safety_fraction
result = max(min_gb, min(recommended, max_gb))
return result
except ImportError:
return min_gb
def sparse_submanifold_conv3d(
feats: torch.Tensor,
coords: torch.Tensor,
shape: tuple,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
neighbor_cache: Optional[torch.Tensor],
dilation: tuple,
max_chunk_mem_gb: float = 6.0,
accumulate_f32: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if feats.shape[0] == 0:
Co = weight.shape[0]
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
if len(shape) == 5:
_, _, W, H, D = shape
else:
W, H, D = shape
Co, Kw, Kh, Kd, Ci = weight.shape
V = Kw * Kh * Kd
device = feats.device
sentinel = -1
max_chunk_mem_gb = get_recommended_chunk_mem(device)
if neighbor_cache is None:
b_stride = W * H * D
x_stride = H * D
y_stride = D
z_stride = 1
flat_keys = (coords[:, 0].long() * b_stride +
coords[:, 1].long() * x_stride +
coords[:, 2].long() * y_stride +
coords[:, 3].long() * z_stride)
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)
neighbor = build_submanifold_neighbor_map(
hashmap, coords, W, H, D, Kw, Kh, Kd,
dilation[0], dilation[1], dilation[2]
)
else:
neighbor = neighbor_cache
N_pts = feats.shape[0]
if accumulate_f32:
weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous()
output = torch.zeros(N_pts, Co, device=device, dtype=torch.float32)
else:
weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous()
output = torch.zeros(N_pts, Co, device=device, dtype=feats.dtype)
# ------------------------------------------------------------------
# Chunk size from memory budget
# ------------------------------------------------------------------
bytes_per_elem = 4 if accumulate_f32 else feats.element_size()
mem_per_row = V * Ci * bytes_per_elem
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
chunk_size = max(1, int(max_chunk_mem / mem_per_row))
chunk_size = min(chunk_size, N_pts)
# ------------------------------------------------------------------
# Chunked forward pass
# Each iteration:
# 1. gather (chunk, V, Ci) memory bound
# 2. mask zero invalids in-place, no extra alloc
# 3. reshape (chunk, V*Ci)
# 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) cuBLAS
# written directly into output slice via out= argument
# ------------------------------------------------------------------
for start in range(0, N_pts, chunk_size):
end = min(start + chunk_size, N_pts)
actual_chunk = end - start
# (chunk, V) int32
chunk_neighbor = neighbor[start:end]
chunk_valid = chunk_neighbor != sentinel
# Clamp sentinel -1 → 0 for safe indexing. No clone of the full map.
chunk_idx = chunk_neighbor.clamp(min=0).long()
# Gather: (chunk, V, Ci). Memory-bound, single index_select.
gathered = feats[chunk_idx]
# Zero invalid neighbours in-place. gathered is a fresh tensor from
# advanced indexing, so in-place mutation is safe.
gathered.mul_(chunk_valid.unsqueeze(-1))
# Reshape to (chunk, V*Ci)
gathered_flat = gathered.view(actual_chunk, V * Ci)
if accumulate_f32:
gathered_flat = gathered_flat.to(torch.float32)
# Single GEMM call per chunk, written directly into output.
# This avoids allocating a temporary (chunk, Co) tensor.
torch.matmul(gathered_flat, weight_T, out=output[start:end])
if accumulate_f32:
output = output.to(feats.dtype)
if bias is not None:
output = output + bias.unsqueeze(0).to(output.dtype)
return output, neighbor
class Mesh:
def __init__(self,
vertices,
faces,
vertex_attrs=None
):
self.vertices = vertices.float()
self.faces = faces.int()
self.vertex_attrs = vertex_attrs
@property
def device(self):
return self.vertices.device
def to(self, device, non_blocking=False):
return Mesh(
self.vertices.to(device, non_blocking=non_blocking),
self.faces.to(device, non_blocking=non_blocking),
self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
)
def cuda(self, non_blocking=False):
return self.to('cuda', non_blocking=non_blocking)
def cpu(self):
return self.to('cpu')

View File

@ -1,935 +0,0 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
from typing import Optional, Tuple, Literal, Union, List
from comfy.ldm.trellis2.attention import (
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
)
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
from comfy.ldm.flux.math import apply_rope, apply_rope1
class SparseGELU(nn.GELU):
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(super().forward(input.feats))
class SparseFeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations),
SparseGELU(approximate="tanh"),
SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations),
)
def forward(self, x: VarLenTensor) -> VarLenTensor:
return self.mlp(x)
def manual_cast(obj, dtype):
return obj.to(dtype=dtype)
class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = manual_cast(x, torch.float32)
o = super().forward(x)
return manual_cast(o, x_dtype)
class SparseMultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int, device, dtype):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
x_type = x.dtype
x = x.float()
if isinstance(x, VarLenTensor):
x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
else:
x = F.normalize(x, dim=-1) * self.gamma * self.scale
return x.to(x_type)
class SparseRotaryPositionEmbedder(nn.Module):
def __init__(
self,
head_dim: int,
dim: int = 3,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
device=None
):
super().__init__()
self.head_dim = head_dim
self.dim = dim
self.rope_freq = rope_freq
self.freq_dim = head_dim // 2 // dim
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
phases_list = []
for i in range(self.dim):
phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device)))
phases = torch.cat(phases_list, dim=-1)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1)
cos = torch.cos(phases)
sin = torch.sin(phases)
f_cis_0 = torch.stack([cos, sin], dim=-1)
f_cis_1 = torch.stack([-sin, cos], dim=-1)
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
return freqs_cis
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
self.freqs = self.freqs.to(indices.device)
phases = torch.outer(indices, self.freqs)
phases = torch.polar(torch.ones_like(phases), phases)
return phases
def forward(self, q, k=None):
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
freqs_cis = q.get_spatial_cache(cache_name)
if freqs_cis is None:
coords = q.coords[..., 1:].to(torch.float32)
freqs_cis = self._get_freqs_cis(coords)
q.register_spatial_cache(cache_name, freqs_cis)
if q.feats.ndim == 3:
f_cis = freqs_cis.unsqueeze(1)
else:
f_cis = freqs_cis
if k is None:
return q.replace(apply_rope1(q.feats, f_cis))
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
return q.replace(q_feats), k.replace(k_feats)
@staticmethod
def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = x_complex * phases.unsqueeze(-2)
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
return x_embed
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
def forward(self, indices: torch.Tensor) -> torch.Tensor:
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
if torch.is_complex(phases):
phases = phases.to(torch.complex64)
else:
phases = phases.to(torch.float32)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
phases = torch.cat([phases, torch.polar(
torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32),
torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32)
)], dim=-1)
return phases
class SparseMultiHeadAttention(nn.Module):
def __init__(
self,
channels: int,
num_heads: int,
ctx_channels: Optional[int] = None,
type: Literal["self", "cross"] = "self",
attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
qkv_bias: bool = True,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
device=None, dtype=None, operations=None
):
super().__init__()
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.attn_mode = attn_mode
self.window_size = window_size
self.shift_window = shift_window
self.use_rope = use_rope
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype)
else:
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
if self.qk_rms_norm:
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
if use_rope:
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device)
@staticmethod
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
return x.replace(module(x.feats))
else:
return module(x)
@staticmethod
def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
return x.reshape(*shape)
else:
return x.reshape(*x.shape[:2], *shape)
def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
x_feats = x.feats.unsqueeze(0)
else:
x_feats = x
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
if self._type == "self":
dtype = next(self.to_qkv.parameters()).dtype
x = x.to(dtype)
qkv = self._linear(self.to_qkv, x)
qkv = self._fused_pre(qkv, num_fused=3)
if self.qk_rms_norm or self.use_rope:
q, k, v = qkv.unbind(dim=-3)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
q, k = self.rope(q, k)
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
if self.attn_mode == "full":
h = sparse_scaled_dot_product_attention(qkv)
elif self.attn_mode == "windowed":
h = sparse_windowed_scaled_dot_product_self_attention(
qkv, self.window_size, shift_window=self.shift_window
)
elif self.attn_mode == "double_windowed":
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
h0 = sparse_windowed_scaled_dot_product_self_attention(
qkv0, self.window_size, shift_window=(0, 0, 0)
)
h1 = sparse_windowed_scaled_dot_product_self_attention(
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
)
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
else:
q = self._linear(self.to_q, x)
q = self._reshape_chs(q, (self.num_heads, -1))
dtype = next(self.to_kv.parameters()).dtype
context = context.to(dtype)
kv = self._linear(self.to_kv, context)
kv = self._fused_pre(kv, num_fused=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k, v = kv.unbind(dim=-3)
k = self.k_rms_norm(k)
h = sparse_scaled_dot_product_attention(q, k, v)
else:
h = sparse_scaled_dot_product_attention(q, kv)
h = self._reshape_chs(h, (-1,))
h = self._linear(self.to_out, h)
return h
class ModulatedSparseTransformerCrossBlock(nn.Module):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "swin"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
device=None, dtype=None, operations=None
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.self_attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
device=device, dtype=dtype, operations=operations
)
self.cross_attn = SparseMultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
device=device, dtype=dtype, operations=operations
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = x.replace(self.norm1(x.feats))
h = h * (1 + scale_msa) + shift_msa
h = self.self_attn(h)
h = h * gate_msa
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.cross_attn(h, context)
x = x + h
h = x.replace(self.norm3(x.feats))
h = h * (1 + scale_mlp) + shift_mlp
h = self.mlp(h)
h = h * gate_mlp
x = x + h
return x
def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
return self._forward(x, mod, context)
class SLatFlowModel(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
model_channels: int,
cond_channels: int,
out_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
pe_mode: Literal["ape", "rope"] = "rope",
rope_freq: Tuple[float, float] = (1.0, 10000.0),
use_checkpoint: bool = False,
share_mod: bool = False,
initialization: str = 'vanilla',
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
dtype = None,
device = None,
operations = None,
):
super().__init__()
self.resolution = resolution
self.in_channels = in_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.pe_mode = pe_mode
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = dtype
self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
)
self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations)
self.blocks = nn.ModuleList([
ModulatedSparseTransformerCrossBlock(
model_channels,
cond_channels,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
attn_mode='full',
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
rope_freq=rope_freq,
share_mod=self.share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_blocks)
])
self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def forward(
self,
x: SparseTensor,
t: torch.Tensor,
cond: Union[torch.Tensor, List[torch.Tensor]],
concat_cond: Optional[SparseTensor] = None,
**kwargs
) -> SparseTensor:
if concat_cond is not None:
x = sparse_cat([x, concat_cond], dim=-1)
if isinstance(cond, list):
cond = VarLenTensor.from_tensor_list(cond)
dtype = next(self.input_layer.parameters()).dtype
x = x.to(dtype)
h = self.input_layer(x)
h = manual_cast(h, self.dtype)
t = t.to(dtype)
t_embedder = self.t_embedder.to(dtype)
t_emb = t_embedder(t, out_dtype = t.dtype)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = manual_cast(t_emb, self.dtype)
cond = manual_cast(cond, self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond)
h = manual_cast(h, x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h)
return h
class FeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype),
nn.GELU(approximate="tanh"),
operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int, device=None, dtype=None):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels: int,
num_heads: int,
ctx_channels: Optional[int]=None,
type: Literal["self", "cross"] = "self",
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
qkv_bias: bool = True,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
device=None, dtype=None, operations=None
):
super().__init__()
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.attn_mode = attn_mode
self.window_size = window_size
self.shift_window = shift_window
self.use_rope = use_rope
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
else:
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
if self.qk_rms_norm:
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
B, L, C = x.shape
if self._type == "self":
x = x.to(next(self.to_qkv.parameters()).dtype)
qkv = self.to_qkv(x)
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
if self.attn_mode == "full":
if self.qk_rms_norm or self.use_rope:
q, k, v = qkv.unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
assert phases is not None, "Phases must be provided for RoPE"
q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
h = scaled_dot_product_attention(q, k, v)
else:
h = scaled_dot_product_attention(qkv)
else:
Lkv = context.shape[1]
q = self.to_q(x)
context = context.to(next(self.to_kv.parameters()).dtype)
kv = self.to_kv(context)
q = q.reshape(B, L, self.num_heads, -1)
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k, v = kv.unbind(dim=2)
k = self.k_rms_norm(k)
h = scaled_dot_product_attention(q, k, v)
else:
h = scaled_dot_product_attention(q, kv)
h = h.reshape(B, L, -1)
h = self.to_out(h)
return h
class ModulatedTransformerCrossBlock(nn.Module):
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
device=None, dtype=None, operations=None
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.self_attn = MultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
device=device, dtype=dtype, operations=operations
)
self.cross_attn = MultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
device=device, dtype=dtype, operations=operations
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = self.norm1(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
h = self.self_attn(h, phases=phases)
h = h * gate_msa.unsqueeze(1)
x = x + h
h = self.norm2(x)
h = self.cross_attn(h, context)
x = x + h
h = self.norm3(x)
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
h = self.mlp(h)
h = h * gate_mlp.unsqueeze(1)
x = x + h
return x
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
return self._forward(x, mod, context, phases)
class SparseStructureFlowModel(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
model_channels: int,
cond_channels: int,
out_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
pe_mode: Literal["ape", "rope"] = "rope",
rope_freq: Tuple[float, float] = (1.0, 10000.0),
use_checkpoint: bool = False,
share_mod: bool = False,
initialization: str = 'vanilla',
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
operations=None,
device = None,
dtype = torch.float32,
**kwargs
):
super().__init__()
self.device = device
self.resolution = resolution
self.in_channels = in_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.pe_mode = pe_mode
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = dtype
self.device = device
self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
)
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device)
coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij')
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
rope_phases = pos_embedder(coords)
self.register_buffer("rope_phases", rope_phases, persistent=False)
if pe_mode != "rope":
self.rope_phases = None
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
self.blocks = nn.ModuleList([
ModulatedTransformerCrossBlock(
model_channels,
cond_channels,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
attn_mode='full',
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
rope_freq=rope_freq,
share_mod=share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_blocks)
])
self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
h = h.to(next(self.input_layer.parameters()).dtype)
h = self.input_layer(h)
t_emb = self.t_embedder(t, out_dtype = t.dtype)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = manual_cast(t_emb, self.dtype)
h = manual_cast(h, self.dtype)
cond = manual_cast(cond, self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond, self.rope_phases)
h = manual_cast(h, x.dtype)
h = F.layer_norm(h, h.shape[-1:])
h = h.to(next(self.out_layer.parameters()).dtype)
h = self.out_layer(h)
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
return h
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
t_shifted = t_shifted / 1000.0
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
t_new *= 1000.0
return t_new
class Trellis2(nn.Module):
def __init__(self, resolution,
in_channels = 32,
out_channels = 32,
model_channels = 1536,
cond_channels = 1024,
num_blocks = 30,
num_heads = 12,
mlp_ratio = 5.3334,
share_mod = True,
qk_rms_norm = True,
qk_rms_norm_cross = True,
init_txt_model=False, # for now
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
operations = operations or nn
# for some reason it passes num_heads = -1
if num_heads == -1:
num_heads = 12
args = {
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
}
txt_only = kwargs.get("txt_only", False)
if not txt_only:
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
self.shape2txt = None
if init_txt_model:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
args.pop("out_channels")
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
else:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.guidance_interval = [0.6, 1.0]
self.guidance_interval_txt = [0.6, 0.9]
def forward(self, x, timestep, context, **kwargs):
transformer_options = kwargs.get("transformer_options", {})
model_options = {}
if hasattr(self, "meta"):
model_options = self.meta
timestep = timestep.to(x.dtype)
embeds = kwargs.get("embeds")
if embeds is None:
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
is_1024 = True#self.img2shape.resolution == 1024
coords = model_options.get("coords", None)
coord_counts = model_options.get("coord_counts", None)
mode = model_options.get("generation_mode", "structure_generation")
is_512_run = False
if mode == "shape_generation_512":
is_512_run = True
mode = "shape_generation"
if coords is not None:
if x.ndim == 4:
x = x.squeeze(-1).transpose(1, 2)
not_struct_mode = True
else:
mode = "structure_generation"
not_struct_mode = False
if x.size(-1) == 16 and x.size(-2) == 16:
mode = "structure_generation"
not_struct_mode = False
if not not_struct_mode:
bsz = x.size(0)
x = x[:, :8]
x = x.view(bsz, 8, 16, 16, 16)
if is_1024 and not_struct_mode and not is_512_run:
context = embeds
sigmas = transformer_options.get("sigmas")[0].item()
if sigmas < 1.00001:
timestep *= 1000.0
if context.size(0) > 1:
cond = context.chunk(2)[1]
else:
cond = context
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
if not_struct_mode:
orig_bsz = x.shape[0]
rule = txt_rule if mode == "texture_generation" else shape_rule
# CFG Bypass Slicing
if rule and orig_bsz > 1:
half = orig_bsz // 2
x_eval = x[half:]
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
c_eval = cond
else:
x_eval = x
t_eval = timestep
c_eval = context
B, N, C = x_eval.shape
# Vectorized SparseTensor Construction
if mode in ["shape_generation", "texture_generation"]:
if coord_counts is not None:
logical_batch = coord_counts.shape[0]
# Duplicate coords if CFG is active
if B > logical_batch:
c_pos = coords.clone()
c_pos[:, 0] += logical_batch
batched_coords = torch.cat([coords, c_pos], dim=0)
counts_eval = torch.cat([coord_counts, coord_counts], dim=0)
else:
batched_coords = coords
counts_eval = coord_counts
# Create boolean mask [B, N] to drop the padded zeros instantly
mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1)
feats_flat = x_eval[mask]
else:
feats_flat = x_eval.reshape(-1, C)
coords_list =[]
for i in range(B):
c = coords.clone()
c[:, 0] = i
coords_list.append(c)
batched_coords = torch.cat(coords_list, dim=0)
mask = None
else:
batched_coords = coords
feats_flat = x_eval
mask = None
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
if mode == "shape_generation":
if is_512_run:
out = self.img2shape_512(x_st, t_eval, c_eval)
else:
out = self.img2shape(x_st, t_eval, c_eval)
elif mode == "texture_generation":
if self.shape2txt is None:
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
slat = model_options.get("shape_slat")
if slat is None:
raise ValueError("shape_slat can't be None")
slat_feats = slat
# Duplicate shape context if CFG is active
if coord_counts is not None and B > coord_counts.shape[0]:
slat_feats = torch.cat([slat_feats, slat_feats], dim=0)
elif coord_counts is None:
slat_feats = slat_feats[:N].repeat(B, 1)
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1))
out = self.shape2txt(x_st, t_eval, c_eval)
else: # structure
orig_bsz = x.shape[0]
if shape_rule and orig_bsz > 1:
half = orig_bsz // 2
x_eval = x[half:]
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
out = self.structure_model(x_eval, t_eval, cond)
out = out.repeat(2, 1, 1, 1, 1)
else:
out = self.structure_model(x, timestep, context)
if not_struct_mode:
if mask is not None:
# Instantly scatter the valid tokens back into a padded rectangular tensor
padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype)
padded_out[mask] = out.feats
out_tensor = padded_out.transpose(1, 2).unsqueeze(-1)
else:
out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
if rule and orig_bsz > 1:
out_tensor = out_tensor.repeat(2, 1, 1, 1)
return out_tensor
else:
out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 0, 24))
return out

File diff suppressed because it is too large Load Diff

View File

@ -53,7 +53,6 @@ import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.trellis2.model
import comfy.ldm.ace.ace_step15
import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4
@ -1638,16 +1637,6 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class Trellis2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2):
super().__init__(model_config, model_type, device, unet_model)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
embeds = kwargs.get("embeds")
out["embeds"] = comfy.conds.CONDRegular(embeds)
return out
class WAN21_FlowRVS(WAN21):
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
model_config.unet_config["model_type"] = "t2v"
@ -1689,6 +1678,7 @@ class WAN21_SCAIL(WAN21):
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
return out
class WAN22_WanDancer(WAN21):

View File

@ -113,30 +113,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config = {}
unet_config["image_model"] = "trellis2"
unet_config["init_txt_model"] = False
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config["init_txt_model"] = True
unet_config["resolution"] = 64
if metadata is not None:
if "is_512" in metadata:
unet_config["resolution"] = 32
unet_config["num_heads"] = 12
return unet_config
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
unet_config = {}
unet_config["image_model"] = "trellis2"
unet_config["resolution"] = 64
unet_config["num_heads"] = 12
unet_config["txt_only"] = True
return unet_config
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
unet_config = {}
unet_config["audio_model"] = "dit1.0"

View File

@ -1613,6 +1613,16 @@ class ModelPatcherDynamic(ModelPatcher):
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
def restore_loaded_backups(self):
restored = self.model.model_loaded_weight_memory
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
self.model.model_loaded_weight_memory = 0
return restored
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
@ -1629,7 +1639,7 @@ class ModelPatcherDynamic(ModelPatcher):
num_patches = 0
allocated_size = 0
self.model.model_loaded_weight_memory = 0
self.restore_loaded_backups()
with self.use_ejected():
self.unpatch_hooks()
@ -1716,6 +1726,9 @@ class ModelPatcherDynamic(ModelPatcher):
force_load=True
if force_load:
if hasattr(m, "_v"):
comfy_aimdo.model_vbar.vbar_unpin(m._v)
delattr(m, "_v")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
@ -1773,13 +1786,7 @@ class ModelPatcherDynamic(ModelPatcher):
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
if freed < memory_to_free:
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0
freed += self.restore_loaded_backups()
return freed

View File

@ -15,7 +15,6 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.lightricks.vae.audio_vae
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.trellis2.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
@ -529,18 +528,6 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd or "txt_dec.blocks.3.4.conv2.weight" in sd: # trellis2 or trellis2 texture only
init_txt_model = False
init_txt_model_only = False
if "shape_dec.blocks.1.16.to_subdiv.weight" not in sd:
init_txt_model_only = True
if "txt_dec.blocks.1.16.norm1.weight" in sd:
init_txt_model = True
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# TODO
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model, init_txt_model_only= init_txt_model_only)
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}

View File

@ -1318,29 +1318,6 @@ class WAN22_T2V(WAN21_T2V):
out = model_base.WAN22(self, image_to_video=True, device=device)
return out
class Trellis2(supported_models_base.BASE):
unet_config = {
"image_model": "trellis2"
}
sampling_settings = {
"shift": 3.0,
}
memory_usage_factor = 3.5
latent_format = latent_formats.Trellis2
vae_key_prefix = ["vae."]
clip_vision_prefix = "conditioner.main_image_encoder.model."
# this is only needed for the texture model
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
return model_base.Trellis2(self, device=device)
def clip_target(self, state_dict={}):
return None
class WAN21_FlowRVS(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1807,7 +1784,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
class ACEStep15(supported_models_base.BASE):
unet_config = {
"audio_model": "ace1.5",
@ -1847,6 +1823,7 @@ class ACEStep15(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
class LongCatImage(supported_models_base.BASE):
unet_config = {
"image_model": "flux",
@ -1924,7 +1901,6 @@ class ErnieImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
class SAM3(supported_models_base.BASE):
unet_config = {"image_model": "SAM3"}
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -2044,6 +2020,7 @@ class CogVideoX_Inpaint(CogVideoX_T2V):
out = model_base.CogVideoX(self, image_to_video=True, device=device)
return out
models = [
LotusD,
Stable_Zero123,
@ -2130,5 +2107,4 @@ models = [
CogVideoX_I2V,
CogVideoX_T2V,
SVD_img2vid,
Trellis2
]

View File

@ -1019,10 +1019,11 @@ def bislerp(samples, width, height):
def lanczos(samples, width, height):
#the below API is strict and expects grayscale to be squeezed
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
if samples.ndim == 4:
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
result = torch.stack(images)
return result.to(samples.device, samples.dtype)

View File

@ -7,10 +7,9 @@ import torch
class VOXEL:
def __init__(self, data: torch.Tensor, voxel_colors=None, resolution=None):
def __init__(self, data: torch.Tensor):
self.data = data
self.voxel_colors = voxel_colors
self.resolution = resolution # each 3d model has its own resolution
class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,

View File

@ -35,6 +35,19 @@ class AnthropicMessage(BaseModel):
content: list[AnthropicTextContent | AnthropicImageContent] = Field(...)
class AnthropicThinkingConfig(BaseModel):
type: Literal["enabled", "disabled", "adaptive"] = Field(...)
budget_tokens: int | None = Field(
None, ge=1024,
description="Reasoning budget in tokens. Used when type is 'enabled'. Must be less than max_tokens.",
)
class AnthropicOutputConfig(BaseModel):
"""Used with `thinking.type='adaptive'` on models like Opus 4.7."""
effort: Literal["low", "medium", "high"] | None = Field(None)
class AnthropicMessagesRequest(BaseModel):
model: str = Field(...)
messages: list[AnthropicMessage] = Field(...)
@ -44,6 +57,8 @@ class AnthropicMessagesRequest(BaseModel):
top_p: float | None = Field(None, ge=0.0, le=1.0)
top_k: int | None = Field(None, ge=0)
stop_sequences: list[str] | None = Field(None)
thinking: AnthropicThinkingConfig | None = Field(None)
output_config: AnthropicOutputConfig | None = Field(None)
class AnthropicResponseTextBlock(BaseModel):
@ -51,6 +66,14 @@ class AnthropicResponseTextBlock(BaseModel):
text: str = Field(...)
class AnthropicResponseThinkingBlock(BaseModel):
type: Literal["thinking"] = "thinking"
thinking: str = Field(...)
AnthropicResponseBlock = AnthropicResponseTextBlock | AnthropicResponseThinkingBlock
class AnthropicCacheCreationUsage(BaseModel):
ephemeral_5m_input_tokens: int | None = Field(None)
ephemeral_1h_input_tokens: int | None = Field(None)
@ -69,7 +92,7 @@ class AnthropicMessagesResponse(BaseModel):
type: str | None = Field(None)
role: str | None = Field(None)
model: str | None = Field(None)
content: list[AnthropicResponseTextBlock] | None = Field(None)
content: list[AnthropicResponseBlock] | None = Field(None)
stop_reason: str | None = Field(None)
stop_sequence: str | None = Field(None)
usage: AnthropicMessagesUsage | None = Field(None)

View File

@ -0,0 +1,93 @@
"""Pydantic models for the OpenRouter chat completions API.
See: https://openrouter.ai/docs/api/api-reference/chat/send-chat-completion-request
"""
from typing import Literal
from pydantic import BaseModel, Field
class OpenRouterTextContent(BaseModel):
type: Literal["text"] = "text"
text: str = Field(...)
class OpenRouterImageUrl(BaseModel):
url: str = Field(...)
class OpenRouterImageContent(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: OpenRouterImageUrl = Field(...)
class OpenRouterVideoUrl(BaseModel):
url: str = Field(...)
class OpenRouterVideoContent(BaseModel):
type: Literal["video_url"] = "video_url"
video_url: OpenRouterVideoUrl = Field(...)
OpenRouterContentBlock = OpenRouterTextContent | OpenRouterImageContent | OpenRouterVideoContent
class OpenRouterMessage(BaseModel):
role: Literal["system", "user", "assistant"] = Field(...)
content: str | list[OpenRouterContentBlock] = Field(...)
class OpenRouterReasoningConfig(BaseModel):
effort: str | None = Field(None)
exclude: bool | None = Field(None, description="If true, model reasons but reasoning is excluded from response.")
class OpenRouterWebSearchOptions(BaseModel):
search_context_size: str | None = Field(None)
class OpenRouterChatRequest(BaseModel):
model: str = Field(...)
messages: list[OpenRouterMessage] = Field(...)
seed: int | None = Field(None)
reasoning: OpenRouterReasoningConfig | None = Field(None)
web_search_options: OpenRouterWebSearchOptions | None = Field(None)
stream: bool = Field(False)
class OpenRouterUsage(BaseModel):
prompt_tokens: int | None = Field(None)
completion_tokens: int | None = Field(None)
total_tokens: int | None = Field(None)
cost: float | None = Field(None, description="Server-side authoritative USD cost of the call.")
class OpenRouterResponseMessage(BaseModel):
role: str | None = Field(None)
content: str | None = Field(None)
reasoning: str | None = Field(None)
refusal: str | None = Field(None)
class OpenRouterChoice(BaseModel):
index: int | None = Field(None)
message: OpenRouterResponseMessage | None = Field(None)
finish_reason: str | None = Field(None)
class OpenRouterError(BaseModel):
code: int | str | None = Field(None)
message: str | None = Field(None)
metadata: dict | None = Field(None)
class OpenRouterChatResponse(BaseModel):
id: str | None = Field(None)
model: str | None = Field(None)
object: str | None = Field(None)
provider: str | None = Field(None)
choices: list[OpenRouterChoice] | None = Field(None)
usage: OpenRouterUsage | None = Field(None)
error: OpenRouterError | None = Field(None)

View File

@ -9,8 +9,11 @@ from comfy_api_nodes.apis.anthropic import (
AnthropicMessage,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
AnthropicOutputConfig,
AnthropicResponseTextBlock,
AnthropicRole,
AnthropicTextContent,
AnthropicThinkingConfig,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -32,15 +35,29 @@ CLAUDE_MODELS: dict[str, str] = {
"Haiku 4.5": "claude-haiku-4-5-20251001",
}
_THINKING_UNSUPPORTED = {"Haiku 4.5"}
# Models that use the newer "adaptive" thinking mode (Opus 4.7 requires it; older models keep the explicit budget API).
# Anthropic decides the actual budget when adaptive is used, based on the `output_config.effort` hint.
_ADAPTIVE_THINKING_MODELS = {"Opus 4.7", "Opus 4.6", "Sonnet 4.6"}
def _claude_model_inputs():
return [
# Budget mode (Sonnet 4.5): effort -> reasoning budget in tokens. Must be < max_tokens.
# Sized so even the "high" budget fits comfortably under the default max_tokens=32768.
_REASONING_BUDGET: dict[str, int] = {
"low": 2048,
"medium": 8192,
"high": 16384,
}
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
def _claude_model_inputs(model_label: str):
inputs: list = [
IO.Int.Input(
"max_tokens",
default=16000,
min=32,
max=32000,
tooltip="Maximum number of tokens to generate before stopping.",
default=32768,
min=4096,
max=64000,
tooltip="Maximum number of tokens to generate (includes reasoning tokens when enabled).",
advanced=True,
),
IO.Float.Input(
@ -49,10 +66,24 @@ def _claude_model_inputs():
min=0.0,
max=1.0,
step=0.01,
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.",
tooltip=(
"Controls randomness. 0.0 is deterministic, 1.0 is most random. "
"Ignored for Opus 4.7 and any model when reasoning_effort is set."
),
advanced=True,
),
]
if model_label not in _THINKING_UNSUPPORTED:
inputs.append(
IO.Combo.Input(
"reasoning_effort",
options=_REASONING_EFFORTS,
default="off",
tooltip="Extended thinking effort. 'off' disables reasoning.",
advanced=True,
)
)
return inputs
def _model_price_per_million(model: str) -> tuple[float, float] | None:
@ -95,7 +126,11 @@ def calculate_tokens_price(response: AnthropicMessagesResponse) -> float | None:
def _get_text_from_response(response: AnthropicMessagesResponse) -> str:
if not response.content:
return ""
return "\n".join(block.text for block in response.content if block.text)
# Thinking blocks are silently dropped — we never want reasoning in the output.
return "\n".join(
block.text for block in response.content
if isinstance(block, AnthropicResponseTextBlock) and block.text
)
async def _build_image_content_blocks(
@ -133,7 +168,10 @@ class ClaudeNode(IO.ComfyNode):
),
IO.DynamicCombo.Input(
"model",
options=[IO.DynamicCombo.Option(label, _claude_model_inputs()) for label in CLAUDE_MODELS],
options=[
IO.DynamicCombo.Option(label, _claude_model_inputs(label))
for label in CLAUDE_MODELS
],
tooltip="The Claude model used to generate the response.",
),
IO.Int.Input(
@ -207,8 +245,29 @@ class ClaudeNode(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_label = model["model"]
max_tokens = model["max_tokens"]
temperature = None if model_label == "Opus 4.7" else model["temperature"]
max_tokens = model.get("max_tokens", 32768)
reasoning_effort = model.get("reasoning_effort", "off")
thinking_enabled = reasoning_effort not in ("off", None) and model_label not in _THINKING_UNSUPPORTED
# Anthropic requires temperature to be unset (defaults to 1.0) when thinking is enabled.
# Opus 4.7 also rejects user-supplied temperature.
if thinking_enabled or model_label == "Opus 4.7":
temperature = None
else:
temperature = model.get("temperature", 1.0)
thinking_cfg: AnthropicThinkingConfig | None = None
output_cfg: AnthropicOutputConfig | None = None
if thinking_enabled:
if model_label in _ADAPTIVE_THINKING_MODELS:
# Adaptive mode - Anthropic chooses the budget based on effort hint
thinking_cfg = AnthropicThinkingConfig(type="adaptive")
output_cfg = AnthropicOutputConfig(effort=reasoning_effort)
else:
# Budget mode (Sonnet 4.5). Leave at least 1024 tokens for the actual response
budget = _REASONING_BUDGET[reasoning_effort]
budget = min(budget, max(1024, max_tokens - 1024))
thinking_cfg = AnthropicThinkingConfig(type="enabled", budget_tokens=budget)
image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None]
if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES:
@ -229,6 +288,8 @@ class ClaudeNode(IO.ComfyNode):
messages=[AnthropicMessage(role=AnthropicRole.user, content=content)],
system=system_prompt or None,
temperature=temperature,
thinking=thinking_cfg,
output_config=output_cfg,
),
price_extractor=calculate_tokens_price,
)

View File

@ -43,15 +43,16 @@ from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_video_to_max_pixels,
get_number_of_images,
image_tensor_pair_to_batch,
poll_op,
resize_video_to_pixel_budget,
sync_op,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
upscale_video_to_min_pixels,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
@ -110,12 +111,13 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
max_px = limits.get("max")
if min_px and pixels < min_px:
raise ValueError(
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
f"Reference video {index} is too small: {w}x{h} = {pixels:,} total pixels. "
f"Minimum for this model is {min_px:,} total pixels."
)
if max_px and pixels > max_px:
raise ValueError(
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
f"Reference video {index} is too large: {w}x{h} = {pixels:,} total pixels. "
f"Maximum for this model is {max_px:,} total pixels. Try downscaling the video."
)
@ -1676,14 +1678,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
"first_frame_asset_id",
default="",
tooltip="Seedance asset_id to use as the first frame. "
"Mutually exclusive with the first_frame image input.",
"Mutually exclusive with the first_frame image input.",
optional=True,
),
IO.String.Input(
"last_frame_asset_id",
default="",
tooltip="Seedance asset_id to use as the last frame. "
"Mutually exclusive with the last_frame image input.",
"Mutually exclusive with the last_frame image input.",
optional=True,
),
IO.Int.Input(
@ -1865,11 +1867,20 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
IO.Boolean.Input(
"auto_downscale",
default=False,
advanced=True,
optional=True,
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
),
IO.Boolean.Input(
"auto_upscale",
default=False,
advanced=True,
optional=True,
tooltip="Automatically upscale reference videos that are below the model's minimum pixel count "
"for the selected resolution. Aspect ratio is preserved; videos already meeting the minimum are "
"untouched. Note: upscaling a low-resolution source does not add real detail and may produce "
"lower-quality generations.",
),
IO.Autogrow.Input(
"reference_assets",
template=IO.Autogrow.TemplateNames(
@ -2030,7 +2041,13 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max")
if max_px:
for key in reference_videos:
reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px)
reference_videos[key] = downscale_video_to_max_pixels(reference_videos[key], max_px)
if model.get("auto_upscale") and reference_videos:
min_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("min")
if min_px:
for key in reference_videos:
reference_videos[key] = upscale_video_to_min_pixels(reference_videos[key], min_px)
total_video_duration = 0.0
for i, key in enumerate(reference_videos, 1):

View File

@ -0,0 +1,374 @@
"""API Nodes for OpenRouter LLM chat completions."""
from dataclasses import dataclass
from typing import Literal
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.openrouter import (
OpenRouterChatRequest,
OpenRouterChatResponse,
OpenRouterContentBlock,
OpenRouterImageContent,
OpenRouterImageUrl,
OpenRouterMessage,
OpenRouterReasoningConfig,
OpenRouterTextContent,
OpenRouterVideoContent,
OpenRouterVideoUrl,
OpenRouterWebSearchOptions,
)
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
)
OPENROUTER_CHAT_ENDPOINT = "/proxy/openrouter/api/v1/chat/completions"
Profile = Literal["standard", "reasoning", "frontier_reasoning", "perplexity", "perplexity_reasoning"]
@dataclass(frozen=True)
class _ModelSpec:
slug: str # exact OpenRouter model id
profile: Profile
price_in: float # USD per token (prompt)
price_out: float # USD per token (completion)
max_images: int = 0 # 0 = no image input; otherwise max URL-passed images supported
max_videos: int = 0 # 0 = no video input; otherwise max URL-passed videos supported
MODELS: list[_ModelSpec] = [
_ModelSpec("anthropic/claude-opus-4.7", "frontier_reasoning", 0.000005, 0.000025, max_images=20),
_ModelSpec("openai/gpt-5.5-pro", "frontier_reasoning", 0.00003, 0.00018, max_images=20),
_ModelSpec("openai/gpt-5.5", "frontier_reasoning", 0.000005, 0.00003, max_images=20),
_ModelSpec("google/gemini-3.5-flash", "reasoning", 0.0000015, 0.000009, max_images=20, max_videos=4),
_ModelSpec("x-ai/grok-4.20", "reasoning", 0.00000125, 0.0000025, max_images=20),
_ModelSpec("x-ai/grok-4.3", "reasoning", 0.00000125, 0.0000025, max_images=20),
_ModelSpec("deepseek/deepseek-v4-pro", "reasoning", 0.000000435, 0.00000087),
_ModelSpec("deepseek/deepseek-v4-flash", "reasoning", 0.000000112, 0.000000224),
_ModelSpec("deepseek/deepseek-v3.2", "reasoning", 0.000000252, 0.000000378),
_ModelSpec("qwen/qwen3.6-max-preview", "reasoning", 0.00000104, 0.00000624),
_ModelSpec("qwen/qwen3.6-plus", "reasoning", 0.000000325, 0.00000195, max_images=10, max_videos=4),
_ModelSpec("qwen/qwen3.6-flash", "reasoning", 0.0000001875, 0.000001125, max_images=10, max_videos=4),
_ModelSpec("mistralai/mistral-large-2512", "standard", 0.0000005, 0.0000015, max_images=8),
_ModelSpec("mistralai/mistral-medium-3-5", "reasoning", 0.0000015, 0.0000075, max_images=8),
_ModelSpec("z-ai/glm-4.6", "reasoning", 0.00000043, 0.00000174),
_ModelSpec("z-ai/glm-5", "reasoning", 0.0000006, 0.00000192),
_ModelSpec("moonshotai/kimi-k2.6", "reasoning", 0.00000073, 0.00000349, max_images=10),
_ModelSpec("moonshotai/kimi-k2-thinking", "reasoning", 0.0000006, 0.0000025),
_ModelSpec("perplexity/sonar-pro", "perplexity", 0.000003, 0.000015),
_ModelSpec("perplexity/sonar-reasoning-pro", "perplexity_reasoning", 0.000002, 0.000008),
_ModelSpec("perplexity/sonar-deep-research", "perplexity_reasoning", 0.000002, 0.000008),
]
_MODELS_BY_SLUG: dict[str, _ModelSpec] = {m.slug: m for m in MODELS}
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
_SEARCH_CONTEXT_SIZES = ["low", "medium", "high"]
def _reasoning_extra_inputs() -> list:
return [
IO.Combo.Input(
"reasoning_effort",
options=_REASONING_EFFORTS,
default="off",
tooltip="Reasoning effort. 'off' disables reasoning entirely.",
advanced=True,
),
]
def _perplexity_extra_inputs() -> list:
return [
IO.Combo.Input(
"search_context_size",
options=_SEARCH_CONTEXT_SIZES,
default="medium",
tooltip="How much web search context to retrieve. Larger = more grounded but slower/pricier.",
advanced=True,
),
]
def _profile_inputs(profile: Profile) -> list:
if profile == "standard":
return []
if profile in ("reasoning", "frontier_reasoning"):
return _reasoning_extra_inputs()
if profile == "perplexity":
return _perplexity_extra_inputs()
if profile == "perplexity_reasoning":
return _perplexity_extra_inputs() + _reasoning_extra_inputs()
raise ValueError(f"Unknown profile: {profile}")
def _media_inputs(spec: _ModelSpec) -> list:
extras: list = []
if spec.max_images > 0:
extras.append(
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, spec.max_images + 1)],
min=0,
),
tooltip=f"Optional reference image(s) — up to {spec.max_images}. Sent as URLs.",
)
)
if spec.max_videos > 0:
extras.append(
IO.Autogrow.Input(
"videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=[f"video_{i}" for i in range(1, spec.max_videos + 1)],
min=0,
),
tooltip=f"Optional reference video(s) — up to {spec.max_videos}. Sent as URLs.",
)
)
return extras
def _inputs_for_model(spec: _ModelSpec) -> list:
return _profile_inputs(spec.profile) + _media_inputs(spec)
def _build_model_options() -> list[IO.DynamicCombo.Option]:
return [IO.DynamicCombo.Option(spec.slug, _inputs_for_model(spec)) for spec in MODELS]
def _calculate_price(response: OpenRouterChatResponse) -> float | None:
if response.usage and response.usage.cost is not None:
return float(response.usage.cost)
return None
def _price_badge_jsonata() -> str:
rates_pairs = []
for spec in MODELS:
prompt_per_1k = spec.price_in * 1000
completion_per_1k = spec.price_out * 1000
rates_pairs.append(f' "{spec.slug}": [{prompt_per_1k:.8g}, {completion_per_1k:.8g}]')
rates_block = ",\n".join(rates_pairs)
return (
"(\n"
" $rates := {\n"
f"{rates_block}\n"
" };\n"
" $r := $lookup($rates, widgets.model);\n"
" $r ? {\n"
' "type": "list_usd",\n'
' "usd": $r,\n'
' "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }\n'
' } : {"type": "text", "text": "Token-based"}\n'
")"
)
async def _build_image_blocks(
cls: type[IO.ComfyNode], spec: _ModelSpec, images: list[Input.Image]
) -> list[OpenRouterImageContent]:
urls = await upload_images_to_comfyapi(
cls,
images,
max_images=spec.max_images,
total_pixels=2048 * 2048,
mime_type="image/png",
wait_label="Uploading reference images",
)
return [OpenRouterImageContent(image_url=OpenRouterImageUrl(url=url)) for url in urls]
async def _build_video_blocks(cls: type[IO.ComfyNode], videos: list[Input.Video]) -> list[OpenRouterVideoContent]:
blocks: list[OpenRouterVideoContent] = []
total = len(videos)
for idx, video in enumerate(videos):
label = "Uploading reference video"
if total > 1:
label = f"{label} ({idx + 1}/{total})"
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
blocks.append(OpenRouterVideoContent(video_url=OpenRouterVideoUrl(url=url)))
return blocks
def _user_message(prompt: str, media_blocks: list[OpenRouterContentBlock]) -> OpenRouterMessage:
if not media_blocks:
return OpenRouterMessage(role="user", content=prompt)
blocks: list[OpenRouterContentBlock] = list(media_blocks)
blocks.append(OpenRouterTextContent(text=prompt))
return OpenRouterMessage(role="user", content=blocks)
def _build_messages(
system_prompt: str, prompt: str, media_blocks: list[OpenRouterContentBlock]
) -> list[OpenRouterMessage]:
messages: list[OpenRouterMessage] = []
if system_prompt:
messages.append(OpenRouterMessage(role="system", content=system_prompt))
messages.append(_user_message(prompt, media_blocks))
return messages
def _build_request(
slug: str,
system_prompt: str,
prompt: str,
media_blocks: list[OpenRouterContentBlock],
*,
seed: int,
reasoning_effort: str | None,
search_context_size: str | None,
) -> OpenRouterChatRequest:
reasoning_cfg: OpenRouterReasoningConfig | None = None
if reasoning_effort and reasoning_effort != "off":
# exclude=True asks providers to reason internally but not return the trace
reasoning_cfg = OpenRouterReasoningConfig(effort=reasoning_effort, exclude=True)
web_search_cfg: OpenRouterWebSearchOptions | None = None
if search_context_size:
web_search_cfg = OpenRouterWebSearchOptions(search_context_size=search_context_size)
return OpenRouterChatRequest(
model=slug,
messages=_build_messages(system_prompt, prompt, media_blocks),
seed=seed if seed > 0 else None,
reasoning=reasoning_cfg,
web_search_options=web_search_cfg,
)
def _extract_text(response: OpenRouterChatResponse) -> str:
if response.error:
code = response.error.code if response.error.code is not None else "unknown"
raise ValueError(f"OpenRouter error ({code}): {response.error.message or 'no message'}")
if not response.choices:
raise ValueError("Empty response from OpenRouter (no choices).")
message = response.choices[0].message
if not message:
raise ValueError("Empty response from OpenRouter (no message).")
if message.refusal:
raise ValueError(f"Model refused to respond: {message.refusal}")
return message.content or ""
class OpenRouterLLMNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenRouterLLMNode",
display_name="OpenRouter LLM",
category="api node/text/OpenRouter",
essentials_category="Text Generation",
description=(
"Generate text responses through OpenRouter. Routes to a curated set of popular "
"models from xAI, DeepSeek, Qwen, Mistral, Z.AI (GLM), Moonshot (Kimi), and "
"Perplexity Sonar."
),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text input to the model.",
),
IO.DynamicCombo.Input(
"model",
options=_build_model_options(),
tooltip="The OpenRouter model used to generate the response.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for sampling. Set to 0 to omit. Most models treat this as a hint only.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
advanced=True,
tooltip="Foundational instructions that dictate the model's behavior.",
),
],
outputs=[IO.String.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr=_price_badge_jsonata(),
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
slug: str = model["model"]
spec = _MODELS_BY_SLUG.get(slug)
if spec is None:
raise ValueError(f"Unknown OpenRouter model: {slug}")
reasoning_effort: str | None = model.get("reasoning_effort")
search_context_size: str | None = model.get("search_context_size")
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
if image_tensors and sum(get_number_of_images(t) for t in image_tensors) > spec.max_images:
raise ValueError(f"Up to {spec.max_images} images are supported for {slug}.")
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
if video_inputs and len(video_inputs) > spec.max_videos:
raise ValueError(f"Up to {spec.max_videos} videos are supported for {slug}.")
media_blocks: list[OpenRouterContentBlock] = []
if image_tensors:
media_blocks.extend(await _build_image_blocks(cls, spec, image_tensors))
if video_inputs:
media_blocks.extend(await _build_video_blocks(cls, video_inputs))
request = _build_request(
slug,
system_prompt,
prompt,
media_blocks,
seed=seed,
reasoning_effort=reasoning_effort,
search_context_size=search_context_size,
)
response = await sync_op(
cls,
ApiEndpoint(path=OPENROUTER_CHAT_ENDPOINT, method="POST"),
response_model=OpenRouterChatResponse,
data=request,
price_extractor=_calculate_price,
)
return IO.NodeOutput(_extract_text(response))
class OpenRouterExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [OpenRouterLLMNode]
async def comfy_entrypoint() -> OpenRouterExtension:
return OpenRouterExtension()

View File

@ -16,16 +16,17 @@ from .conversions import (
convert_mask_to_image,
downscale_image_tensor,
downscale_image_tensor_by_max_side,
downscale_video_to_max_pixels,
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
resize_video_to_pixel_budget,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
upscale_video_to_min_pixels,
video_to_base64_string,
)
from .download_helpers import (
@ -88,16 +89,17 @@ __all__ = [
"convert_mask_to_image",
"downscale_image_tensor",
"downscale_image_tensor_by_max_side",
"downscale_video_to_max_pixels",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"resize_video_to_pixel_budget",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"upscale_video_to_min_pixels",
"video_to_base64_string",
# Validation utilities
"get_image_dimensions",

View File

@ -415,14 +415,48 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
def downscale_video_to_max_pixels(video: Input.Video, max_pixels: int) -> Input.Video:
"""Downscale a video to fit within ``max_pixels`` (w * h), preserving aspect ratio.
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
"""
src_w, src_h = video.get_dimensions()
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
scale_dims = _compute_downscale_dims(src_w, src_h, max_pixels)
if scale_dims is None:
return video
return _apply_video_scale(video, scale_dims)
def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
"""Return upscaled (w, h) with even dims meeting at least ``total_pixels``, or None if already large enough.
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
are rounded up to even values (many codecs require divisible-by-2). The result is guaranteed to be at
least ``total_pixels``.
"""
pixels = src_w * src_h
if pixels >= total_pixels:
return None
scale = math.sqrt(total_pixels / pixels)
new_w = math.ceil(src_w * scale)
new_h = math.ceil(src_h * scale)
if new_w % 2:
new_w += 1
if new_h % 2:
new_h += 1
return new_w, new_h
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.
Returns the original video object untouched when it already meets the minimum. Preserves frame rate,
duration, and audio. Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
Note: upscaling a low-resolution source does not add real detail; downstream model quality may suffer.
"""
src_w, src_h = video.get_dimensions()
scale_dims = _compute_upscale_dims(src_w, src_h, min_pixels)
if scale_dims is None:
return video
return _apply_video_scale(video, scale_dims)

File diff suppressed because it is too large Load Diff

View File

@ -234,12 +234,6 @@ def save_glb(vertices, faces, filepath, metadata=None,
textures = []
samplers = []
materials = []
pbr = {
"metallicFactor": 0.0,
"roughnessFactor": 0.5,
"baseColorFactor": [0.22, 0.22, 0.22, 1.0],
}
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
buffer_views.append({
"buffer": 0,
@ -249,13 +243,15 @@ def save_glb(vertices, faces, filepath, metadata=None,
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": 0, "sampler": 0})
pbr["baseColorTexture"] = {"index": 0, "texCoord": 0}
materials.append({
"pbrMetallicRoughness": pbr,
"doubleSided": True,
})
primitive["material"] = 0
materials.append({
"pbrMetallicRoughness": {
"baseColorTexture": {"index": 0, "texCoord": 0},
"metallicFactor": 0.0,
"roughnessFactor": 1.0,
},
"doubleSided": True,
})
primitive["material"] = 0
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI"},
@ -377,14 +373,10 @@ class SaveGLB(IO.ComfyNode):
continue
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
f = f"{filename}_{counter:05}_.glb"
save_glb(
vertices_i, faces_i,
os.path.join(full_output_folder, f),
metadata,
uvs=uvs_i,
vertex_colors=v_colors,
texture_image=tex_img,
)
save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata,
uvs=uvs_i,
vertex_colors=v_colors,
texture_image=tex_img)
results.append({
"filename": f,
"subfolder": subfolder,

View File

@ -1,711 +0,0 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types, io
from comfy.ldm.trellis2.vae import SparseTensor
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
import comfy.model_management
from PIL import Image
import numpy as np
import torch
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
def prepare_trellis_vae_for_decode(vae, sample_shape):
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
if len(sample_shape) == 5:
memory_required *= max(1, int(sample_shape[4]))
memory_required = max(1, int(memory_required))
device = comfy.model_management.get_torch_device()
comfy.model_management.load_models_gpu(
[vae.patcher],
memory_required=memory_required,
force_full_load=getattr(vae, "disable_offload", False),
)
free_memory = vae.patcher.get_free_memory(device)
batch_number = max(1, int(free_memory / memory_required))
return batch_number
shape_slat_normalization = {
"mean": torch.tensor([
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
])[None],
"std": torch.tensor([
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
])[None]
}
tex_slat_normalization = {
"mean": torch.tensor([
3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
-1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
])[None],
"std": torch.tensor([
2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
])[None]
}
def shape_norm(shape_latent, coords):
std = shape_slat_normalization["std"].to(shape_latent)
mean = shape_slat_normalization["mean"].to(shape_latent)
samples = SparseTensor(feats = shape_latent, coords=coords)
samples = samples * std + mean
return samples
def infer_batched_coord_layout(coords):
if coords.ndim != 2 or coords.shape[1] != 4:
raise ValueError(f"Expected Trellis2 coords with shape [N, 4], got {tuple(coords.shape)}")
if coords.shape[0] == 0:
raise ValueError("Trellis2 coords can't be empty")
batch_ids = coords[:, 0].to(torch.int64)
if (batch_ids < 0).any():
raise ValueError(f"Trellis2 batch ids must be non-negative, got {batch_ids.unique(sorted=True).tolist()}")
batch_size = int(batch_ids.max().item()) + 1
counts = torch.bincount(batch_ids, minlength=batch_size)
if (counts == 0).any():
raise ValueError(f"Non-contiguous Trellis2 batch ids in coords: {batch_ids.unique(sorted=True).tolist()}")
max_tokens = int(counts.max().item())
return batch_size, counts, max_tokens
def split_batched_coords(coords, coord_counts):
if coord_counts.ndim != 1:
raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}")
if (coord_counts < 0).any():
raise ValueError(f"Trellis2 coord_counts must be non-negative, got {coord_counts.tolist()}")
if int(coord_counts.sum().item()) != coords.shape[0]:
raise ValueError(
f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}"
)
batch_ids = coords[:, 0].to(torch.int64)
order = torch.argsort(batch_ids, stable=True)
sorted_coords = coords.index_select(0, order)
sorted_batch_ids = batch_ids.index_select(0, order)
offsets = coord_counts.cumsum(0) - coord_counts
items = []
for i in range(coord_counts.shape[0]):
count = int(coord_counts[i].item())
start = int(offsets[i].item())
coords_i = sorted_coords[start:start + count]
ids_i = sorted_batch_ids[start:start + count]
if coords_i.shape[0] != count or not torch.all(ids_i == i):
raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}")
items.append(coords_i)
return items
def flatten_batched_sparse_latent(samples, coords, coord_counts):
samples = samples.squeeze(-1).transpose(1, 2)
if coord_counts is None:
return samples.reshape(-1, samples.shape[-1]), coords
coords_items = split_batched_coords(coords, coord_counts)
feat_list = []
coord_list = []
for i, coords_i in enumerate(coords_items):
count = int(coord_counts[i].item())
feat_list.append(samples[i, :count])
coord_list.append(coords_i)
return torch.cat(feat_list, dim=0), torch.cat(coord_list, dim=0)
def split_batched_sparse_latent(samples, coords, coord_counts):
samples = samples.squeeze(-1).transpose(1, 2)
if coord_counts is None:
return [(samples.reshape(-1, samples.shape[-1]), coords)]
coords_items = split_batched_coords(coords, coord_counts)
items = []
for i, coords_i in enumerate(coords_items):
count = int(coord_counts[i].item())
items.append((samples[i, :count], coords_i))
return items
class VaeDecodeShapeTrellis(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VaeDecodeShapeTrellis",
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
],
outputs=[
IO.Mesh.Output("mesh"),
ShapeSubdivides.Output(display_name = "shape_subdivides"),
]
)
@classmethod
def execute(cls, samples, vae):
resolution = int(vae.first_stage_model.resolution.item())
sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
samples = samples["samples"]
if coord_counts is None:
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
samples = shape_norm(samples.to(device), coords.to(device))
mesh, subs = trellis_vae.decode_shape_slat(samples, resolution)
else:
split_items = split_batched_sparse_latent(samples, coords, coord_counts)
mesh = []
subs_per_sample = []
for feats_i, coords_i in split_items:
coords_i = coords_i.to(device).clone()
coords_i[:, 0] = 0
sample_i = shape_norm(feats_i.to(device), coords_i)
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i, resolution)
mesh.append(mesh_i[0])
subs_per_sample.append(subs_i)
subs = []
for stage_index in range(len(subs_per_sample[0])):
stage_tensors = [sample_subs[stage_index] for sample_subs in subs_per_sample]
feats_list = [stage_tensor.feats for stage_tensor in stage_tensors]
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
face_list = [m.faces for m in mesh]
vert_list = [m.vertices for m in mesh]
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
else:
mesh = pack_variable_mesh_batch(vert_list, face_list)
return IO.NodeOutput(mesh, subs)
class VaeDecodeTextureTrellis(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VaeDecodeTextureTrellis",
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
ShapeSubdivides.Input("shape_subdivides",
tooltip=(
"Shape information used to guide higher-detail reconstruction during decoding. "
"Helps preserve structure consistency at higher resolutions."
)),
],
outputs=[
IO.Voxel.Output("voxel_colors"),
]
)
@classmethod
def execute(cls, samples, vae, shape_subdivides):
sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
samples = samples["samples"]
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
samples = samples.to(device)
std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords.to(device))
samples = samples * std + mean
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
color_feats = voxel.feats[:, :3]
voxel_coords = voxel.coords#[:, 1:]
voxel = Types.VOXEL(voxel_coords, color_feats, 1024)
return IO.NodeOutput(voxel)
class VaeDecodeStructureTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VaeDecodeStructureTrellis2",
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Combo.Input("resolution", options=["32", "64"], default="32")
],
outputs=[
IO.Voxel.Output("voxel"),
]
)
@classmethod
def execute(cls, samples, vae, resolution):
resolution = int(resolution)
sample_tensor = samples["samples"]
sample_tensor = sample_tensor[:, :8]
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
decoder = vae.first_stage_model.struct_dec
load_device = comfy.model_management.get_torch_device()
decoded_batches = []
for start in range(0, sample_tensor.shape[0], batch_number):
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
decoded_batches.append(decoder(sample_chunk) > 0)
decoded = torch.cat(decoded_batches, dim=0)
current_res = decoded.shape[2]
if current_res != resolution:
ratio = current_res // resolution
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
out = Types.VOXEL(decoded.squeeze(1).float())
return IO.NodeOutput(out)
class Trellis2UpsampleCascade(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Trellis2UpsampleCascade",
category="latent/3d",
display_name="Trellis2 Upsample Cascade",
description="Upsamples low-resolution Trellis2 shape latents into higher resolution coordinates while respecting the maximum token budget.",
inputs=[
IO.Latent.Input("shape_latent"),
IO.Vae.Input("vae"),
IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024", tooltip="Controls output detail level for upsampling."),
IO.Int.Input("max_tokens", default=49152, min=1024, max=100000,
tooltip=(
"Maximum number of output elements (coordinates) allowed after upsampling. "
"Used to limit memory usage and control mesh density."
))
],
outputs=[
IO.Voxel.Output(
"high_res_voxel",
tooltip=(
"High-resolution sparse coordinates produced after cascade upsampling. "
"Represents the refined 3D structure at target resolution."
)
)
]
)
@classmethod
def execute(cls, shape_latent, vae, target_resolution, max_tokens):
shape_latent_512 = shape_latent
device = comfy.model_management.get_torch_device()
prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape)
coord_counts = shape_latent_512.get("coord_counts")
decoder = vae.first_stage_model.shape_dec
lr_resolution = 512
target_resolution = int(target_resolution)
if coord_counts is None:
feats, coords_512 = flatten_batched_sparse_latent(
shape_latent_512["samples"],
shape_latent_512["coords"],
coord_counts,
)
feats = feats.to(device)
coords_512 = coords_512.to(device)
slat = shape_norm(feats, coords_512)
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
hr_coords = decoder.upsample(slat, upsample_times=4)
hr_resolution = target_resolution
while True:
quant_coords = torch.cat([
hr_coords[:, :1],
((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
], dim=1)
final_coords = quant_coords.unique(dim=0)
num_tokens = final_coords.shape[0]
if num_tokens < max_tokens or hr_resolution <= 1024:
break
hr_resolution -= 128
return IO.NodeOutput(final_coords,)
items = split_batched_sparse_latent(
shape_latent_512["samples"],
shape_latent_512["coords"],
coord_counts,
)
decoder_dtype = next(decoder.parameters()).dtype
sample_hr_coords = []
for feats_i, coords_i in items:
feats_i = feats_i.to(device)
coords_i = coords_i.to(device).clone()
coords_i[:, 0] = 0
slat_i = shape_norm(feats_i, coords_i)
slat_i.feats = slat_i.feats.to(decoder_dtype)
sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))
hr_resolution = target_resolution
while True:
exceeds_limit = False
for hr_coords_i in sample_hr_coords:
quant_coords_i = torch.cat([
hr_coords_i[:, :1],
((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
], dim=1)
if quant_coords_i.unique(dim=0).shape[0] >= max_tokens:
exceeds_limit = True
break
if not exceeds_limit or hr_resolution <= 1024:
break
hr_resolution -= 128
final_coords_list = []
output_coord_counts = []
for sample_offset, hr_coords_i in enumerate(sample_hr_coords):
quant_coords_i = torch.cat([
hr_coords_i[:, :1],
((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
], dim=1)
final_coords_i = quant_coords_i.unique(dim=0)
final_coords_i = final_coords_i.clone()
final_coords_i[:, 0] = sample_offset
final_coords_list.append(final_coords_i)
output_coord_counts.append(int(final_coords_i.shape[0]))
coords = torch.cat(final_coords_list, dim=0)
output = Types.VOXEL(coords)
output.coord_counts = torch.tensor(output_coord_counts, dtype=torch.int64)
output.resolutions = torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64)
output.upsampled = True
return IO.NodeOutput(output,)
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
def run_conditioning(model, cropped_img_tensor, include_1024=True):
model_internal = model.model
device = comfy.model_management.intermediate_device()
torch_device = comfy.model_management.get_torch_device()
def prepare_tensor(pil_img, size):
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
img_np = np.array(resized_pil).astype(np.float32) / 255.0
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
model_internal.image_size = 512
input_512 = prepare_tensor(cropped_img_tensor, 512)
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
cond_1024 = None
if include_1024:
model_internal.image_size = 1024
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
conditioning = {
'cond_512': cond_512.to(device),
'neg_cond': torch.zeros_like(cond_512).to(device),
}
if cond_1024 is not None:
conditioning['cond_1024'] = cond_1024.to(device)
return conditioning
class Trellis2Conditioning(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Trellis2Conditioning",
category="conditioning/video_models",
inputs=[
IO.ClipVision.Input("clip_vision_model"),
IO.Image.Input("image"),
IO.Mask.Input("mask"),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
]
)
@classmethod
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
# Normalize to batched form so per-image conditioning loop below is uniform.
if image.ndim == 3:
image = image.unsqueeze(0)
elif image.ndim == 4:
if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]:
image = image.permute(0, 2, 3, 1)
# normalize mask to standard [B, H, W] (handling 2D, 3D, and 4D variants)
if mask.ndim == 4:
if mask.shape[1] == 1:
mask = mask.squeeze(1)
elif mask.shape[-1] == 1:
mask = mask.squeeze(-1)
else:
mask = mask[:, :, :, 0] # take first channel as fallback
if mask.ndim == 3:
if mask.shape[-1] == 1:
mask = mask.squeeze(-1).unsqueeze(0)
elif mask.ndim == 2:
mask = mask.unsqueeze(0)
batch_size = image.shape[0]
if mask.shape[0] == 1 and batch_size > 1:
mask = mask.expand(batch_size, -1, -1)
elif mask.shape[0] != batch_size:
raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")
cond_512_list = []
cond_1024_list = []
for b in range(batch_size):
item_image = image[b]
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
# Ensure img_np is either 2D (grayscale) or 3D (RGB/RGBA)
if img_np.ndim == 3 and img_np.shape[-1] == 1:
img_np = img_np.squeeze(-1)
mask_np = mask_np.squeeze()
# detect inverted mask
border_pixels = np.concatenate([
mask_np[0, :], mask_np[-1, :], mask_np[:, 0], mask_np[:, -1]
])
if np.mean(border_pixels) > 127:
mask_np = 255 - mask_np
mask_np[mask_np < 35] = 0
border_shave = 4
mask_np[:border_shave, :] = 0
mask_np[-border_shave:, :] = 0
mask_np[:, :border_shave] = 0
mask_np[:, -border_shave:] = 0
pil_img = Image.fromarray(img_np)
pil_mask = Image.fromarray(mask_np)
max_size = max(pil_img.size)
scale = min(1.0, 1024 / max_size)
if scale < 1.0:
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
rgba_np[:, :, :3] = np.array(pil_img.convert("RGB"))
rgba_np[:, :, 3] = np.array(pil_mask)
alpha = rgba_np[:, :, 3]
bbox_coords = np.argwhere(alpha > 0.8 * 255)
if len(bbox_coords) > 0:
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
size = max(y_max - y_min, x_max - x_min)
crop_x1 = int(center_x - size // 2)
crop_y1 = int(center_y - size // 2)
crop_x2 = int(center_x + size // 2)
crop_y2 = int(center_y + size // 2)
rgba_pil = Image.fromarray(rgba_np)
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
else:
import logging
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
cropped_np = rgba_np.astype(np.float32) / 255.0
bg_rgb = np.array([0.0, 0.0, 0.0], dtype=np.float32)
fg = cropped_np[:, :, :3]
alpha_float = cropped_np[:, :, 3:4]
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
# Keep the image as 4-channel RGBA to force TRELLIS to bypass its internal background remover
rgb_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
alpha_uint8 = (alpha_float.squeeze(-1) * 255.0).round().clip(0, 255).astype(np.uint8)
rgba_composite = np.zeros((cropped_np.shape[0], cropped_np.shape[1], 4), dtype=np.uint8)
rgba_composite[:, :, :3] = rgb_uint8
rgba_composite[:, :, 3] = alpha_uint8
cropped_pil = Image.fromarray(rgba_composite, mode="RGBA")
# Convert to RGB to ensure the CLIP/DINO model receives a 3-channel image
item_conditioning = run_conditioning(clip_vision_model, cropped_pil.convert("RGB"), include_1024=True)
cond_512_list.append(item_conditioning["cond_512"])
cond_1024_list.append(item_conditioning["cond_1024"])
cond_512_batched = torch.cat(cond_512_list, dim=0)
cond_1024_batched = torch.cat(cond_1024_list, dim=0)
neg_cond_batched = torch.zeros_like(cond_512_batched)
neg_embeds_batched = torch.zeros_like(cond_1024_batched)
positive = [[cond_512_batched, {"embeds": cond_1024_batched}]]
negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]]
return IO.NodeOutput(positive, negative)
class EmptyTrellis2ShapeLatent(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyTrellis2ShapeLatent",
category="latent/3d",
inputs=[
IO.Voxel.Input(
"voxel",
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled voxel coordinates from a previous cascade stage."
)
)
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, voxel):
# to accept the upscaled coords
is_512_pass = False
upsampled = hasattr(voxel, "upsampled")
if upsampled:
voxel = voxel.data
if not upsampled:
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
is_512_pass = True
else:
coords = voxel.int()
is_512_pass = False
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
in_channels = 32
# image like format
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
if is_512_pass:
generation_mode = "shape_generation_512"
else:
generation_mode = "shape_generation"
return IO.NodeOutput({"samples": latent, "coords": coords, "coord_counts": counts, "type": "trellis2",
"model_options": {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}})
class EmptyTrellis2LatentTexture(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyTrellis2LatentTexture",
category="latent/3d",
inputs=[
IO.Voxel.Input(
"voxel",
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled voxel coordinates from a previous cascade stage."
)
),
IO.Latent.Input("shape_latent"),
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, voxel, shape_latent):
channels = 32
upsampled = hasattr(voxel, "upsampled")
if upsampled:
voxel = voxel.data
if not upsampled:
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
else:
coords = voxel.int()
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
shape_latent = shape_latent["samples"]
if shape_latent.ndim == 4:
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
latent = torch.zeros(batch_size, channels, max_tokens, 1)
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
"model_options": {"generation_mode": "texture_generation",
"coords": coords, "coord_counts": counts, "shape_slat": shape_latent}})
class EmptyTrellis2LatentStructure(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyTrellis2LatentStructure",
category="latent/3d",
inputs=[
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, batch_size):
in_channels = 8
resolution = 16
latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution)
output = {
"samples": latent,
"type": "trellis2",
}
return IO.NodeOutput(output)
class Trellis2Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Trellis2Conditioning,
EmptyTrellis2ShapeLatent,
EmptyTrellis2LatentStructure,
EmptyTrellis2LatentTexture,
VaeDecodeTextureTrellis,
VaeDecodeShapeTrellis,
VaeDecodeStructureTrellis2,
Trellis2UpsampleCascade,
]
async def comfy_entrypoint() -> Trellis2Extension:
return Trellis2Extension()

33
main.py
View File

@ -27,6 +27,7 @@ from utils.mime_types import init_mime_types
import faulthandler
import logging
import sys
import traceback
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
@ -135,7 +136,20 @@ def apply_custom_paths():
folder_paths.set_user_directory(user_dir)
# Buffer for prestartup failures. Recorded into `nodes.NODE_STARTUP_ERRORS`
# only AFTER the normal `import nodes` line below, so a failing prestartup
# script never triggers an early `import nodes` (and therefore `import torch`)
# on the error path.
_PRESTARTUP_FAILURES: list[dict] = []
def execute_prestartup_script():
"""Run every custom_nodes/*/prestartup_script.py once, before importing nodes.
Failures are buffered into the module-level ``_PRESTARTUP_FAILURES`` list and
must be flushed via ``record_node_startup_error`` after ``import nodes`` has
happened at its normal bootstrap point.
"""
if args.disable_all_custom_nodes and len(args.whitelist_custom_nodes) == 0:
return
@ -148,6 +162,15 @@ def execute_prestartup_script():
return True
except Exception as e:
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
# Buffer the failure - do NOT `import nodes` here, that would drag
# torch in before the intended bootstrap point.
_PRESTARTUP_FAILURES.append({
"module_path": os.path.dirname(script_path),
"source": "custom_nodes",
"phase": "prestartup",
"error": e,
"tb": traceback.format_exc(),
})
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")
@ -207,6 +230,16 @@ import execution
import server
from protocol import BinaryEventTypes
import nodes
# Flush any prestartup failures that were buffered before `nodes` was
# importable. Doing this here (rather than from the prestartup error
# handler) keeps the bootstrap order deterministic: `nodes` (and torch)
# import at this single line whether prestartup succeeded or failed.
if _PRESTARTUP_FAILURES:
for _failure in _PRESTARTUP_FAILURES:
nodes.record_node_startup_error(**_failure)
_PRESTARTUP_FAILURES.clear()
import comfy.model_management
import comfyui_version
import app.logger

155
nodes.py
View File

@ -1537,10 +1537,6 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
if "model_options" in latent:
inner = model.model.diffusion_model
inner.meta = latent["model_options"]
callback = latent_preview.prepare_callback(model, steps)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
@ -2162,6 +2158,137 @@ EXTENSION_WEB_DIRS = {}
# Dictionary of successfully loaded module names and associated directories.
LOADED_MODULE_DIRS = {}
# Dictionary of custom node startup errors, keyed by "<source>:<module_name>"
# so that name collisions across custom_nodes / comfy_extras / comfy_api_nodes
# do not overwrite each other. Each value contains: source, module_name,
# module_path, error, traceback, phase.
#
# `source` is the same string as the internal `module_parent` used at load
# time (e.g. "custom_nodes", "comfy_extras", "comfy_api_nodes"). It is
# intentionally a free-form string rather than a fixed enum so the contract
# survives node-source layouts evolving (e.g. comfy_api_nodes eventually
# moving out of core). Consumers should treat any new value as a new bucket
# rather than rejecting it.
NODE_STARTUP_ERRORS: dict[str, dict] = {}
_EMPTY_LEAF_VALUES = (None, "", [], {})
def _prune_empty(value):
"""Recursively drop empty strings / lists / dicts / None from a nested structure.
Used to keep the on-wire pyproject payload tight without altering the
nesting that callers see (so consumers can still parse it back through
``PyProjectConfig`` if they want a typed object).
"""
if isinstance(value, dict):
out = {}
for k, v in value.items():
cleaned = _prune_empty(v)
if cleaned not in _EMPTY_LEAF_VALUES:
out[k] = cleaned
return out
if isinstance(value, list):
return [
cleaned
for cleaned in (_prune_empty(v) for v in value)
if cleaned not in _EMPTY_LEAF_VALUES
]
return value
def _read_pyproject_metadata(module_path: str) -> dict | None:
"""Best-effort extraction of pyproject.toml for a node module.
Returns a dict mirroring the ``PyProjectConfig`` shape produced by
``comfy_config.config_parser.extract_node_configuration`` (i.e. with
``project`` and ``tool_comfy`` nesting and the same field names) when the
module directory contains a pyproject.toml. Empty / default-valued leaves
are pruned so the API payload stays compact, but the nesting is kept
intact so API consumers can parse the result back through
``PyProjectConfig`` directly.
Returns None when no toml is present or parsing fails for any reason —
startup-error tracking must never itself raise.
"""
if not module_path or not os.path.isdir(module_path):
return None
toml_path = os.path.join(module_path, "pyproject.toml")
if not os.path.isfile(toml_path):
return None
try:
from comfy_config import config_parser
cfg = config_parser.extract_node_configuration(module_path)
if cfg is None:
return None
pruned = _prune_empty(cfg.model_dump())
return pruned or None
except Exception:
return None
def record_node_startup_error(
*, module_path: str, source: str, phase: str, error: BaseException, tb: str
) -> None:
"""Record a startup error for a node module so it can be exposed via the API."""
module_name = get_module_name(module_path)
entry = {
"source": source,
"module_name": module_name,
"module_path": module_path,
"error": str(error),
"traceback": tb,
"phase": phase,
}
pyproject = _read_pyproject_metadata(module_path)
if pyproject:
entry["pyproject"] = pyproject
NODE_STARTUP_ERRORS[f"{source}:{module_name}"] = entry
def filter_node_startup_errors(
*,
source: str | None = None,
module_name: str | None = None,
pack_id: str | None = None,
) -> dict[str, dict[str, dict]]:
"""Return `NODE_STARTUP_ERRORS` reshaped for the public HTTP endpoint.
Entries are grouped by their ``source`` bucket (the same string as the
internal ``module_parent`` used at load time). The on-disk
``module_path`` is stripped from each entry — it's an internal detail
useful only for server-side logging and would leak absolute filesystem
layout otherwise.
Optional filters narrow the response and combine with AND:
* ``source`` — only entries from this source bucket.
* ``module_name`` — only entries whose module name matches exactly.
* ``pack_id`` — only entries whose ``pyproject.project.name``
matches exactly. Entries without a parsed
pyproject.toml can never match this filter.
A non-matching filter returns an empty dict, not an error — absence of
a failure is a valid answer for this query.
"""
grouped: dict[str, dict[str, dict]] = {}
for entry in NODE_STARTUP_ERRORS.values():
entry_source = entry.get("source", "custom_nodes")
if source is not None and entry_source != source:
continue
if module_name is not None and entry.get("module_name") != module_name:
continue
if pack_id is not None:
pyproject = entry.get("pyproject") or {}
project = pyproject.get("project") or {}
if project.get("name") != pack_id:
continue
public_entry = {k: v for k, v in entry.items() if k != "module_path"}
grouped.setdefault(entry_source, {})[entry["module_name"]] = public_entry
return grouped
def get_module_name(module_path: str) -> str:
"""
@ -2271,14 +2398,30 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
return True
except Exception as e:
tb = traceback.format_exc()
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
record_node_startup_error(
module_path=module_path,
source=module_parent,
phase="entrypoint",
error=e,
tb=tb,
)
return False
else:
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
return False
except Exception as e:
logging.warning(traceback.format_exc())
tb = traceback.format_exc()
logging.warning(tb)
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
record_node_startup_error(
module_path=module_path,
source=module_parent,
phase="import",
error=e,
tb=tb,
)
return False
async def init_external_custom_nodes():
@ -2434,8 +2577,6 @@ async def init_builtin_extra_nodes():
"nodes_toolkit.py",
"nodes_replacements.py",
"nodes_nag.py",
"nodes_trellis2.py",
"nodes_mesh_postprocess.py",
"nodes_sdpose.py",
"nodes_math.py",
"nodes_number_convert.py",

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.43.18
comfyui-workflow-templates==0.9.79
comfyui-workflow-templates==0.9.82
comfyui-embedded-docs==0.5.0
torch
torchsde

View File

@ -765,6 +765,46 @@ class PromptServer():
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/node_startup_errors")
async def get_node_startup_errors(request):
"""Return startup errors recorded during node loading, grouped by source.
Group errors by source so the frontend/Manager can render them in
distinct sections. ``source`` is the same string as the
``module_parent`` used at load time (e.g. ``"custom_nodes"``,
``"comfy_extras"``, ``"comfy_api_nodes"``) and is left as a
free-form string so the contract survives node-source layouts
evolving. The response only contains source buckets that actually
had a failure; consumers should not assume any particular set of
keys is always present.
``module_path`` is stripped because the absolute on-disk path is
internal detail that the frontend has no use for.
Optional query parameters narrow the response:
* ``source`` — only entries from this source bucket.
* ``module_name`` — only entries whose module name matches exactly.
(Folder name for directory-style packs, file
stem for single-file modules.)
* ``pack_id`` — only entries whose ``pyproject.project.name``
matches exactly. Entries without a parsed
pyproject.toml are skipped under this filter.
Filters are combined with AND. Filtering an empty / non-matching
result still returns ``{}`` with HTTP 200 rather than 404 — absence
of an error is a valid answer for this endpoint.
"""
# Coalesce empty-string query values to None so `?source=` (param
# present but blank) is treated the same as the param being absent
# — rather than filtering for entries whose source is literally "".
grouped = nodes.filter_node_startup_errors(
source=request.query.get("source") or None,
module_name=request.query.get("module_name") or None,
pack_id=request.query.get("pack_id") or None,
)
return web.json_response(grouped)
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.

View File

@ -0,0 +1,258 @@
"""Tests for the custom node startup error tracking introduced for
Comfy-Org/ComfyUI-Launcher#303.
Covers:
- load_custom_node populates NODE_STARTUP_ERRORS with the correct source
for each module_parent (custom_nodes / comfy_extras / comfy_api_nodes).
- Composite keying prevents collisions between modules with the same name
in different sources.
- record_node_startup_error stores the expected fields.
- pyproject.toml metadata is attached when present and omitted when absent.
"""
import textwrap
import pytest
import nodes
@pytest.fixture(autouse=True)
def _clear_startup_errors():
nodes.NODE_STARTUP_ERRORS.clear()
yield
nodes.NODE_STARTUP_ERRORS.clear()
def _write_broken_module(tmp_path, name: str) -> str:
path = tmp_path / f"{name}.py"
path.write_text(textwrap.dedent("""\
# Deliberately broken module to exercise startup-error tracking.
raise RuntimeError("boom from " + __name__)
"""))
return str(path)
def test_record_node_startup_error_fields(tmp_path):
err = ValueError("kaboom")
nodes.record_node_startup_error(
module_path=str(tmp_path / "my_pack"),
source="custom_nodes",
phase="import",
error=err,
tb="traceback-text",
)
assert "custom_nodes:my_pack" in nodes.NODE_STARTUP_ERRORS
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:my_pack"]
assert entry["source"] == "custom_nodes"
assert entry["module_name"] == "my_pack"
assert entry["phase"] == "import"
assert entry["error"] == "kaboom"
assert entry["traceback"] == "traceback-text"
assert entry["module_path"].endswith("my_pack")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"module_parent",
["custom_nodes", "comfy_extras", "comfy_api_nodes"],
)
async def test_load_custom_node_records_source(tmp_path, module_parent):
# `source` in the entry should be the same string as `module_parent`.
module_path = _write_broken_module(tmp_path, "broken_pack")
success = await nodes.load_custom_node(module_path, module_parent=module_parent)
assert success is False
key = f"{module_parent}:broken_pack"
assert key in nodes.NODE_STARTUP_ERRORS, nodes.NODE_STARTUP_ERRORS
entry = nodes.NODE_STARTUP_ERRORS[key]
assert entry["source"] == module_parent
assert entry["module_name"] == "broken_pack"
assert entry["phase"] == "import"
assert "boom from" in entry["error"]
assert "RuntimeError" in entry["traceback"]
@pytest.mark.asyncio
async def test_load_custom_node_collision_across_sources(tmp_path):
# Same module name registered as both a custom node and a comfy_extra;
# composite keying should keep both entries.
cn_dir = tmp_path / "cn"
extras_dir = tmp_path / "extras"
cn_dir.mkdir()
extras_dir.mkdir()
cn_path = _write_broken_module(cn_dir, "nodes_audio")
extras_path = _write_broken_module(extras_dir, "nodes_audio")
assert await nodes.load_custom_node(cn_path, module_parent="custom_nodes") is False
assert await nodes.load_custom_node(extras_path, module_parent="comfy_extras") is False
assert "custom_nodes:nodes_audio" in nodes.NODE_STARTUP_ERRORS
assert "comfy_extras:nodes_audio" in nodes.NODE_STARTUP_ERRORS
assert (
nodes.NODE_STARTUP_ERRORS["custom_nodes:nodes_audio"]["module_path"]
!= nodes.NODE_STARTUP_ERRORS["comfy_extras:nodes_audio"]["module_path"]
)
@pytest.mark.asyncio
async def test_load_custom_node_attaches_pyproject_metadata(tmp_path):
pack_dir = tmp_path / "MyCoolPack"
pack_dir.mkdir()
(pack_dir / "__init__.py").write_text("raise RuntimeError('boom')\n")
(pack_dir / "pyproject.toml").write_text(textwrap.dedent("""\
[project]
name = "comfyui-mycoolpack"
version = "1.2.3"
[project.urls]
Repository = "https://github.com/example/comfyui-mycoolpack"
[tool.comfy]
PublisherId = "example"
DisplayName = "My Cool Pack"
"""))
success = await nodes.load_custom_node(str(pack_dir), module_parent="custom_nodes")
assert success is False
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:MyCoolPack"]
assert "pyproject" in entry, entry
py = entry["pyproject"]
# Shape must mirror PyProjectConfig 1:1 so consumers can parse it back
# through the same pydantic model used by comfy_config.config_parser.
project = py["project"]
assert project["name"] == "comfyui-mycoolpack"
assert project["version"] == "1.2.3"
assert project["urls"]["repository"] == "https://github.com/example/comfyui-mycoolpack"
tool_comfy = py["tool_comfy"]
assert tool_comfy["publisher_id"] == "example"
assert tool_comfy["display_name"] == "My Cool Pack"
def test_prune_empty_drops_empty_leaves_only():
src = {
"keep_str": "x",
"drop_empty_str": "",
"drop_none": None,
"drop_empty_list": [],
"drop_empty_dict": {},
"keep_zero": 0,
"keep_false": False,
"nested": {
"drop_me": "",
"keep_me": "y",
"deeper": {"only_empties": ""},
},
"list_of_dicts": [{"a": ""}, {"a": "z"}],
}
result = nodes._prune_empty(src)
assert result == {
"keep_str": "x",
"keep_zero": 0,
"keep_false": False,
"nested": {"keep_me": "y"},
"list_of_dicts": [{"a": "z"}],
}
@pytest.mark.asyncio
async def test_load_custom_node_no_pyproject_skips_metadata(tmp_path):
# Single-file extras-style module: no pyproject.toml exists alongside it,
# so the entry must not contain a 'pyproject' key.
module_path = _write_broken_module(tmp_path, "lonely")
assert await nodes.load_custom_node(module_path, module_parent="comfy_extras") is False
entry = nodes.NODE_STARTUP_ERRORS["comfy_extras:lonely"]
assert "pyproject" not in entry
@pytest.mark.asyncio
async def test_load_custom_node_arbitrary_module_parent_passes_through(tmp_path):
# `source` is a free-form string — an unknown module_parent (e.g. a future
# node-source bucket) should be recorded as-is, not coerced or rejected.
module_path = _write_broken_module(tmp_path, "future_pack")
assert await nodes.load_custom_node(module_path, module_parent="future_source") is False
entry = nodes.NODE_STARTUP_ERRORS["future_source:future_pack"]
assert entry["source"] == "future_source"
# ---------------------------------------------------------------------------
# Tests for the public reshape/filter helper (nodes.filter_node_startup_errors).
# The HTTP route is a thin wrapper around this helper, so unit-testing it
# directly avoids spinning up an aiohttp app while still covering every
# query-param branch.
# ---------------------------------------------------------------------------
def _seed(*, source, module_name, pack_id=None, module_path="/abs/path"):
"""Insert a synthetic entry directly into NODE_STARTUP_ERRORS."""
entry = {
"source": source,
"module_name": module_name,
"module_path": module_path,
"error": "boom",
"traceback": "tb",
"phase": "import",
}
if pack_id is not None:
entry["pyproject"] = {"project": {"name": pack_id}}
nodes.NODE_STARTUP_ERRORS[f"{source}:{module_name}"] = entry
def test_filter_node_startup_errors_strips_module_path_and_groups_by_source():
_seed(source="custom_nodes", module_name="A", module_path="/x/A")
_seed(source="comfy_extras", module_name="B", module_path="/x/B")
grouped = nodes.filter_node_startup_errors()
assert set(grouped) == {"custom_nodes", "comfy_extras"}
assert "module_path" not in grouped["custom_nodes"]["A"]
assert "module_path" not in grouped["comfy_extras"]["B"]
def test_filter_node_startup_errors_source_filter():
_seed(source="custom_nodes", module_name="A")
_seed(source="comfy_extras", module_name="B")
grouped = nodes.filter_node_startup_errors(source="comfy_extras")
assert set(grouped) == {"comfy_extras"}
assert set(grouped["comfy_extras"]) == {"B"}
# Non-matching source filter returns an empty dict, not an error.
assert nodes.filter_node_startup_errors(source="nope") == {}
# An explicit empty-string filter is treated as a real value (matches
# entries whose source is literally ""), NOT silently as "no filter".
# The HTTP route layer is responsible for coalescing `?source=` to None
# before calling this helper; this assertion locks that contract in.
assert nodes.filter_node_startup_errors(source="") == {}
def test_filter_node_startup_errors_module_name_filter():
_seed(source="custom_nodes", module_name="A")
_seed(source="comfy_extras", module_name="A") # same name, different source
_seed(source="custom_nodes", module_name="C")
grouped = nodes.filter_node_startup_errors(module_name="A")
# Both A entries (from different sources) survive the filter and stay in
# their respective source buckets.
assert set(grouped) == {"custom_nodes", "comfy_extras"}
assert set(grouped["custom_nodes"]) == {"A"}
assert set(grouped["comfy_extras"]) == {"A"}
def test_filter_node_startup_errors_pack_id_filter_matches_only_pyproject_entries():
_seed(source="custom_nodes", module_name="A", pack_id="comfyui-foo")
_seed(source="custom_nodes", module_name="B", pack_id="comfyui-bar")
_seed(source="comfy_extras", module_name="C") # no pyproject at all
grouped = nodes.filter_node_startup_errors(pack_id="comfyui-foo")
assert set(grouped) == {"custom_nodes"}
assert set(grouped["custom_nodes"]) == {"A"}
# An entry without a parsed pyproject can never match a pack_id filter.
assert nodes.filter_node_startup_errors(pack_id="anything-else") == {}
def test_filter_node_startup_errors_filters_combine_with_and():
_seed(source="custom_nodes", module_name="A", pack_id="comfyui-foo")
_seed(source="comfy_extras", module_name="A", pack_id="comfyui-foo")
grouped = nodes.filter_node_startup_errors(
source="comfy_extras", pack_id="comfyui-foo"
)
assert set(grouped) == {"comfy_extras"}
assert set(grouped["comfy_extras"]) == {"A"}