Compare commits

..

4 Commits

Author SHA1 Message Date
078d544705 openapi: document Cloud-runtime request fields on POST /api/assets/export (#14120)
The Cloud runtime accepts three request fields on /api/assets/export that the
spec didn't declare:

- job_ids: include all assets associated with the given jobs
- naming_strategy: how to name files in the ZIP (enum, default group_by_job_time)
- job_asset_name_filters: optional per-job asset-name allowlist

Also drops asset_ids from required: the runtime supports exporting by job_ids
alone, so neither field is individually required.

/api/assets/export is already x-runtime: [cloud]; these are plain field
additions under that endpoint-level tag.
2026-05-26 14:32:22 -07:00
011c6bf101 openapi: add Cloud-runtime fields workflow_id, execution_error to JobEntry
The Cloud runtime returns two additional fields on JobEntry that the spec
didn't declare:

- workflow_id: UUID of the Cloud workflow entity the job is associated with
- execution_error: structured ComfyUI execution error for failed jobs
  (reuses the existing ExecutionError schema)

Both tagged x-runtime: [cloud] with [cloud-only] descriptions; local ComfyUI
does not populate them.
2026-05-26 14:25:38 -07:00
cabccdeb38 openapi: fix GET /api/hub/labels response to the label-catalog shape
GET /api/hub/labels returns the catalog of available labels you can filter by,
which the Cloud runtime serves as {labels: HubLabelInfo[]} (slug name,
display_name, and a type category: tag/model/custom_node).

The spec had this operation returning a bare array of HubLabel ({id, name,
color}) — that schema models the label chips attached to a published workflow
(HubWorkflow.labels), a different object. The catalog schema (HubLabelInfo)
already existed but was unreferenced.

Repoints the 200 response to a new HubLabelListResponse wrapper over the
existing HubLabelInfo. HubLabel is unchanged and still used by
HubWorkflow.labels. Endpoint remains x-runtime: [cloud].
2026-05-26 14:25:38 -07:00
e9e30553ca openapi: document QueueManageResponse body on POST /api/queue
The Cloud runtime returns a JSON body from POST /api/queue describing which
prompts were deleted and whether the queue was cleared. The spec previously
declared a bare 200 with no schema, so generated clients had no type for the
response.

Adds a QueueManageResponse schema ({deleted, cleared}) and references it from
the 200 response. Tagged x-runtime: [cloud] with a [cloud-only] description:
local ComfyUI returns an empty 200 body, so both fields are nullable.
2026-05-26 14:25:38 -07:00
24 changed files with 44 additions and 1834 deletions

View File

@ -799,15 +799,13 @@ class ZImagePixelSpace(ChromaRadiance):
"""
pass
class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1.
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
"""
pass
class PixelDiTPixel(ChromaRadiance):
pass
class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
super().__init__()
if output_size is None:
output_size = hidden_size
@ -221,10 +221,9 @@ class TimestepEmbedder(nn.Module):
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
def forward(self, t, dtype, **kwargs):
t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb

View File

@ -1,239 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.hidream.model import FeedForwardSwiGLU
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from .modules import (
FinalLayer,
PatchTokenEmbedder,
PiTBlock,
PixelTokenEmbedder,
apply_adaln_,
precompute_freqs_cis_2d,
)
class MMDiTJointAttention(nn.Module):
"""Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
RoPE is applied to each stream before concatenation so each stream uses its own
2D/1D positional encoding. Concat order is [text, image] (text first).
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
B, Nx, _ = x.shape
_, Ny, _ = y.shape
H = self.num_heads
D = self.head_dim
qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
qx, kx, vx = qkv_x.unbind(0)
qx = self.q_norm_x(qx)
kx = self.k_norm_x(kx)
qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
qy, ky, vy = qkv_y.unbind(0)
qy = self.q_norm_y(qy)
ky = self.k_norm_y(ky)
qx, kx = apply_rope(qx, kx, pos_img[None, None])
if pos_txt is not None:
qy, ky = apply_rope(qy, ky, pos_txt[None, None])
q_joint = torch.cat([qy, qx], dim=2)
k_joint = torch.cat([ky, kx], dim=2)
v_joint = torch.cat([vy, vx], dim=2)
out_joint = optimized_attention(
q_joint, k_joint, v_joint, H,
mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
transformer_options=transformer_options,
)
out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
return self.proj_x(out_x), self.proj_y(out_y)
class MMDiTBlockT2I(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
super().__init__()
self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations)
self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
x = torch.addcmul(x, gate_msa_x, attn_x)
y = torch.addcmul(y, gate_msa_y, attn_y)
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
return x, y
class PixDiT_T2I(nn.Module):
"""PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
(also runs at 512px when fed the appropriate latent size and flow_shift).
Forward:
x: [B, 3, H, W] pixel-space input (no VAE)
timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
Returns flow-matching velocity [B, 3, H, W].
"""
def __init__(
self,
in_channels=3,
num_groups=24,
hidden_size=1536,
pixel_hidden_size=16,
pixel_attn_hidden_size=1152,
pixel_num_groups=16,
patch_depth=14,
pixel_depth=2,
patch_size=16,
txt_embed_dim=2304,
txt_max_length=300,
use_text_rope=True,
text_rope_theta=10000.0,
image_model=None,
dtype=None,
device=None,
operations=None,
pixel_mlp_chunks=2,
):
super().__init__()
self.dtype = dtype
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.patch_depth = patch_depth
self.pixel_depth = pixel_depth
self.patch_size = patch_size
self.pixel_hidden_size = pixel_hidden_size
self.pixel_attn_hidden_size = pixel_attn_hidden_size
self.pixel_num_groups = pixel_num_groups
self.txt_embed_dim = txt_embed_dim
self.txt_max_length = txt_max_length
self.use_text_rope = use_text_rope
self.text_rope_theta = text_rope_theta
self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))
self.patch_blocks = nn.ModuleList([
MMDiTBlockT2I(self.hidden_size, self.num_groups,
dtype=dtype, device=device, operations=operations)
for _ in range(self.patch_depth)
])
self.pixel_blocks = nn.ModuleList([
PiTBlock(
self.pixel_hidden_size,
self.hidden_size,
patch_size=self.patch_size,
num_heads=self.num_groups,
attn_hidden_size=self.pixel_attn_hidden_size,
attn_num_heads=self.pixel_num_groups,
dtype=dtype, device=device, operations=operations,
mlp_chunks=pixel_mlp_chunks,
)
for _ in range(self.pixel_depth)
])
self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)
def _fetch_text_pos(self, length, device, dtype):
return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
def _pre_patch_block(self, s, i, **kwargs):
"""Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
return s
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
H_orig, W_orig = x.shape[2], x.shape[3]
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
B, _, H, W = x.shape
Hs = H // self.patch_size
Ws = W // self.patch_size
L = Hs * Ws
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
if context is None or context.dim() != 3:
raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
Ltxt = min(context.shape[1], self.txt_max_length)
y = context[:, :Ltxt, :]
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
condition = F.silu(t_emb)
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
s = self.s_embedder(x_patches)
for i, blk in enumerate(self.patch_blocks):
s = self._pre_patch_block(s, i, **kwargs)
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
s = F.silu(t_emb + s)
s_cond = s.view(B * L, self.hidden_size)
x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
for blk in self.pixel_blocks:
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
x_pixels = self.final_layer(x_pixels)
C_out = self.out_channels
P2 = self.patch_size * self.patch_size
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return out[:, :, :H_orig, :W_orig]

View File

@ -1,187 +0,0 @@
import torch
import torch.nn as nn
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
def apply_adaln_(x, shift, scale):
return x.addcmul_(x, scale).add_(shift)
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
ref_grid_h=None, ref_grid_w=None,
scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
device=None, dtype=torch.float32, **kwargs):
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
rope_options:
scale_x / scale_y multiply the position range (RoPE extrapolation).
shift_x / shift_y offset the position origin (tiled / regional inference).
With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling
(rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).
Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
"""
dim_axis = dim // 2
if ref_grid_h is not None and dim_axis > 2:
h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2))
w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2))
else:
h_ntk = w_ntk = 1.0
x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
return out.to(dtype=dtype)
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
"""Standard 2D sin/cos absolute positional embedding (ViT-style).
first half encodes W-coordinates, second half H.
"""
assert embed_dim % 4 == 0
grid_h = torch.arange(height, dtype=torch.float32, device=device)
grid_w = torch.arange(width, dtype=torch.float32, device=device)
grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device)
return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)
class RotaryAttention(nn.Module):
"""Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x, pos, mask=None, transformer_options={}):
B, N, C = x.shape
H = self.num_heads
D = self.head_dim
qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
return self.proj(x)
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x):
return self.linear(self.norm(x))
class PatchTokenEmbedder(nn.Module):
"""Linear projection used both for patchified-image tokens and text-feature tokens."""
def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()
def forward(self, x):
return self.norm(self.proj(x))
class PixelTokenEmbedder(nn.Module):
"""Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""
def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
super().__init__()
self.in_channels = in_channels
self.hidden_size_output = hidden_size_output
self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
def forward(self, inputs, patch_size):
B, _, H, W = inputs.shape
Hs, Ws = H // patch_size, W // patch_size
P2 = patch_size * patch_size
x = inputs.permute(0, 2, 3, 1).contiguous()
x = self.proj(x)
pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output)
x = x + pos_full.unsqueeze(0)
x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output)
class PiTBlock(nn.Module):
"""Pixel-level transformer block.
Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
Conditioning is per-pixel adaLN from the patch-level features.
"""
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
super().__init__()
self.pixel_dim = pixel_hidden_size
self.context_dim = patch_hidden_size
self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
assert self.attn_dim % self.num_heads == 0
p2 = patch_size * patch_size
self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations)
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations)
self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
self._rope_fn = precompute_freqs_cis_2d
self.mlp_chunks = max(1, int(mlp_chunks))
def _fetch_pos(self, height, width, device, dtype, **rope_opts):
return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts)
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
BL, P2, _ = x.shape
Hs, Ws = image_height // patch_size, image_width // patch_size
L = Hs * Ws
B = BL // L
# Attention path uses only msa params; compute, use, free before mlp params allocate.
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
x = torch.addcmul(x, gate_msa, attn_exp)
del msa_params, shift_msa, scale_msa, gate_msa
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
del mlp_params, shift_mlp, scale_mlp
# MLP in chunks since the peak memory usage is huge here
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
for s in range(0, BL, chunk_size):
e = min(s + chunk_size, BL)
x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
return x

View File

@ -1,227 +0,0 @@
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
body + LQ projection branch injected before each MMDiT patch block.
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .model import PixDiT_T2I
from .modules import precompute_freqs_cis_2d
class SigmaAwareGatePerTokenPerDim(nn.Module):
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
"""
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))
def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
# log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
gate = torch.sigmoid(content_logit + sigma_offset)
return x + (gate * lq).to(x.dtype)
class ResBlock(nn.Module):
"""Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""
def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
super().__init__()
self.block = nn.Sequential(
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)
class LQProjection2D(nn.Module):
"""LQ latent -> per-block patch-aligned features for controlnet-style injection."""
def __init__(
self,
latent_channels: int,
hidden_dim: int = 512,
out_dim: int = 1536,
patch_size: int = 16,
sr_scale: int = 4,
latent_spatial_down_factor: int = 8,
num_res_blocks: int = 4,
num_outputs: int = 7,
interval: int = 2,
dtype=None, device=None, operations=None,
):
super().__init__()
self.latent_channels = latent_channels
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.patch_size = patch_size
self.sr_scale = sr_scale
self.latent_spatial_down_factor = latent_spatial_down_factor
self.num_outputs = num_outputs
self.interval = interval
z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
self.z_to_patch_ratio = z_to_patch_ratio
if z_to_patch_ratio >= 1:
self.latent_fold_factor = 0
latent_proj_in_ch = latent_channels
else:
fold_factor = int(1 / z_to_patch_ratio)
assert fold_factor * z_to_patch_ratio == 1.0
self.latent_fold_factor = fold_factor
latent_proj_in_ch = latent_channels * fold_factor * fold_factor
layers = [
operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
]
for _ in range(num_res_blocks):
layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
self.latent_proj = nn.Sequential(*layers)
self.output_heads = nn.ModuleList(
[operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
)
self.gate_modules = nn.ModuleList(
[SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
for _ in range(num_outputs)]
)
def is_gate_active(self, block_idx: int) -> bool:
return block_idx % self.interval == 0
def output_index(self, block_idx: int) -> int:
return block_idx // self.interval
def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
return self.gate_modules[out_idx](x, lq_feature, sigma)
def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
B, z_dim = lq_latent.shape[:2]
if self.z_to_patch_ratio >= 1:
if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
else:
z_aligned = lq_latent
else:
f = self.latent_fold_factor
zH_expected, zW_expected = pH * f, pW * f
if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
return self.latent_proj(z_aligned)
def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
B, C, H, W = feat.shape
tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
return [head(tokens) for head in self.output_heads]
class PidNet(PixDiT_T2I):
"""PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""
def __init__(
self,
lq_latent_channels: int = 16,
lq_hidden_dim: int = 512,
lq_num_res_blocks: int = 4,
lq_interval: int = 2,
sr_scale: int = 4,
latent_spatial_down_factor: int = 8,
rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
rope_ref_w: int = 1024,
image_model=None,
dtype=None, device=None, operations=None,
**pixdit_kwargs,
):
super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
self.rope_ref_grid_h = rope_ref_h // self.patch_size
self.rope_ref_grid_w = rope_ref_w // self.patch_size
# Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)
for blk in self.pixel_blocks:
blk._rope_fn = _pit_rope_fn
num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
self.lq_proj = LQProjection2D(
latent_channels=lq_latent_channels,
hidden_dim=lq_hidden_dim,
out_dim=self.hidden_size,
patch_size=self.patch_size,
sr_scale=sr_scale,
latent_spatial_down_factor=latent_spatial_down_factor,
num_res_blocks=lq_num_res_blocks,
num_outputs=num_lq_outputs,
interval=lq_interval,
dtype=dtype,
device=device,
operations=operations,
)
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
return precompute_freqs_cis_2d(
self.hidden_size // self.num_groups,
height, width,
ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
device=device, dtype=dtype, **rope_opts,
)
def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
if not self.lq_proj.is_gate_active(i):
return s
out_idx = self.lq_proj.output_index(i)
if out_idx >= len(pid_lq_features):
return s
return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
if lq_latent is None:
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
expected_c = self.lq_proj.latent_channels
if lq_latent.shape[1] != expected_c:
raise ValueError(
f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
)
B = x.shape[0]
# Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
Hs = -(-x.shape[2] // self.patch_size)
Ws = -(-x.shape[3] // self.patch_size)
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
if degrade_sigma.numel() == 1 and B > 1:
degrade_sigma = degrade_sigma.expand(B).contiguous()
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
return super()._forward(
x, timesteps,
context=context, attention_mask=attention_mask,
transformer_options=transformer_options,
pid_lq_features=lq_features,
pid_degrade_sigma=degrade_sigma,
**kwargs,
)

View File

@ -49,8 +49,6 @@ import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
@ -1399,53 +1397,6 @@ class ZImagePixelSpace(Lumina2):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
class PixelDiTT2I(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
return out
class PiD(PixelDiTT2I):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
BaseModel.__init__(self, model_config, model_type, device=device,
unet_model=comfy.ldm.pixeldit.pid.PidNet)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
lq_latent = kwargs.get("lq_latent", None)
if lq_latent is not None:
out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
degrade_sigma = kwargs.get("degrade_sigma", None)
if degrade_sigma is not None:
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "lq_latent" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
lq = cond_value.cond
dim = window.dim
if dim >= lq.ndim:
return None
lq_proj = self.diffusion_model.lq_proj
ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
# Map x window indices -> lq indices (deduplicated, sorted, in-bounds).
lq_size = lq.size(dim)
lq_indices = sorted({i // ratio for i in window.index_list if 0 <= i // ratio < lq_size})
if not lq_indices:
return None
idx = tuple([slice(None)] * dim + [lq_indices])
return cond_value._copy_with(lq[idx].to(device))
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@ -463,23 +463,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
if _lq_w_key in state_dict_keys:
in_ch = int(state_dict[_lq_w_key].shape[1])
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
num_gates = len({k[len(_gate_prefix):].split('.')[0]
for k in state_dict_keys if k.startswith(_gate_prefix)})
dit_config = {"image_model": "pid",
"lq_latent_channels": in_ch,
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
if num_gates > 0:
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
return dit_config
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
return {"image_model": "pixeldit_t2i"}
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"

View File

@ -49,7 +49,6 @@ import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.pixeldit
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
@ -1286,7 +1285,6 @@ class CLIPType(Enum):
LONGCAT_IMAGE = 26
COGVIDEOX = 27
LENS = 28
PIXELDIT = 29
@ -1530,12 +1528,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
if clip_type == CLIPType.PIXELDIT:
clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
else:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")

View File

@ -30,7 +30,6 @@ import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1
import comfy.text_encoders.pixeldit
from . import supported_models_base
from . import latent_formats
@ -845,8 +844,6 @@ class Lens(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux2
memory_usage_factor = 4.0
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
vae_key_prefix = ["vae."]
@ -1204,72 +1201,6 @@ class ZImagePixelSpace(ZImage):
def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
class PixelDiTT2I(supported_models_base.BASE):
unet_config = {
"image_model": "pixeldit_t2i",
}
unet_extra_config = {}
sampling_settings = {
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
}
latent_format = latent_formats.PixelDiTPixel
memory_usage_factor = 0.04
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
return model_base.PixelDiTT2I(self, device=device)
def process_unet_state_dict(self, state_dict):
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
out = {}
marker = ".adaLN_modulation.0."
for k, v in state_dict.items():
if k.startswith("_repa_projector") or k.startswith("net_ema."):
continue
if k.startswith("core."):
k = k[len("core."):]
elif k.startswith("net."):
k = k[len("net."):]
if "pixel_blocks." in k and marker in k:
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
p2 = v.shape[0] // (6 * pixel_dim)
trail = v.shape[1:] # () for bias, (in_dim,) for weight
vv = v.view(p2, 6, pixel_dim, *trail)
base, suffix = k.split(marker)
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
else:
out[k] = v
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(
comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
)
class PiD(PixelDiTT2I):
unet_config = {
"image_model": "pid",
}
sampling_settings = {
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
}
memory_usage_factor = 0.04
def get_model(self, state_dict, prefix="", device=None):
return model_base.PiD(self, device=device)
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@ -2180,8 +2111,6 @@ models = [
CosmosI2VPredict2,
ZImagePixelSpace,
ZImage,
PiD,
PixelDiTT2I,
Lumina2,
WAN22_T2V,
WAN21_CausalAR_T2V,

View File

@ -1,104 +0,0 @@
import torch
from comfy import sd1_clip
from .lumina2 import Gemma2BTokenizer, LuminaModel
import comfy.text_encoders.llama
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(
device=device, layer=layer, layer_idx=layer_idx,
textmodel_json_config={}, dtype=dtype,
special_tokens={"start": 2, "pad": 0},
layer_norm_hidden_state=False,
model_class=comfy.text_encoders.llama.Gemma2_2B,
enable_attention_masks=attention_mask,
return_attention_masks=attention_mask,
model_options=model_options,
)
_PIXELDIT_CHI_PROMPT = (
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
"suitable for image generation. Evaluate the level of detail in the user prompt:\n"
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
"and spatial relationships to create vivid and concrete scenes.\n"
"- If the prompt is already detailed, refine and enhance the existing details slightly without "
"overcomplicating.\n"
"Here are examples of how to transform or refine prompts:\n"
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
"sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
"glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
"passing by towering glass skyscrapers.\n"
"Please generate only the enhanced description for the prompt below and avoid including any "
"additional commentary or evaluations:\n"
"User Prompt: "
)
_PIXELDIT_MAX_LENGTH = 300
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
name="gemma2_2b", tokenizer=Gemma2BTokenizer)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
if not text.strip():
return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH)
chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids,
disable_weights=True, min_length=max_length_all)
out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]]
return out
def untokenize(self, token_weight_pair):
return self.gemma2_2b.untokenize(token_weight_pair)
def state_dict(self):
return self.gemma2_2b.state_dict()
class PixelDiTGemma2TE(LuminaModel):
# PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b",
clip_model=PixelDiTGemma2_2BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
result = super().encode_token_weights(token_weight_pairs)
cond, pooled = result[0], result[1]
extra = result[2] if len(result) > 2 else None
if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1)
if extra is not None and "attention_mask" in extra:
am = extra["attention_mask"]
extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1)
if extra is not None:
return cond, pooled, extra
return cond, pooled
def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
class PixelDiTTE_(PixelDiTGemma2TE):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixelDiTTE_

View File

@ -770,23 +770,6 @@ class Load3DCamera(ComfyTypeIO):
Type = CameraInfo
@comfytype(io_type="LOAD3D_MODEL_INFO")
class Load3DModelInfo(ComfyTypeIO):
class ModelTransform(TypedDict):
uuid: str
name: str
type: str
position: dict[str, float | int]
rotation: dict[str, float | int | str]
quaternion: dict[str, float | int]
scale: dict[str, float | int]
up: dict[str, float | int]
visible: bool
matrix: list[float]
Type = list[ModelTransform]
@comfytype(io_type="LOAD_3D")
class Load3D(ComfyTypeIO):
"""3D models are stored as a dictionary."""
@ -796,7 +779,6 @@ class Load3D(ComfyTypeIO):
normal: str
camera_info: Load3DCamera.CameraInfo
recording: NotRequired[str]
model_info: NotRequired[list[Load3DModelInfo.ModelTransform]]
Type = Model3DDict
@ -2309,7 +2291,6 @@ __all__ = [
"FlowControl",
"Accumulation",
"Load3DCamera",
"Load3DModelInfo",
"Load3D",
"Load3DAnimation",
"Photomaker",

View File

@ -1,32 +0,0 @@
from pydantic import BaseModel, Field
class CreateSwitchXRequest(BaseModel):
generation_type: str = Field(...)
source_uri: str = Field(...)
alpha_mode: str = Field(...)
prompt: str | None = Field(None, max_length=2000)
reference_image_uri: str | None = Field(None)
alpha_uri: str | None = Field(None)
max_resolution: int = Field(1080)
callback_url: str | None = Field(None)
idempotency_key: str | None = Field(None, max_length=256, min_length=1)
class SwitchXOutputUrls(BaseModel):
render: str | None = Field(None)
source: str | None = Field(None)
alpha: str | None = Field(None)
class SwitchXStatusResponse(BaseModel):
id: str = Field(...)
status: str = Field(...)
progress: int | None = Field(None)
generation_type: str | None = Field(None)
alpha_mode: str | None = Field(None)
output: SwitchXOutputUrls | None = Field(None)
error: str | None = Field(None)
created_at: str | None = Field(None)
modified_at: str | None = Field(None)
completed_at: str | None = Field(None)

View File

@ -158,9 +158,8 @@ class SeedanceCreateAssetResponse(BaseModel):
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
url: str = Field(..., description="Publicly accessible URL of the asset to upload.")
url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
# Dollars per 1K tokens, keyed by (model_id, has_video_input).

View File

@ -1,46 +0,0 @@
"""Pydantic models for the Krea image-generation API."""
from pydantic import BaseModel, Field
class KreaMoodboard(BaseModel):
id: str = Field(...)
strength: float = Field(default=0.35, ge=-0.5, le=1.5)
class KreaImageStyleReference(BaseModel):
strength: float = Field(..., ge=-2.0, le=2.0)
url: str | None = Field(default=None)
class KreaGenerateImageRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str = Field(...)
resolution: str = Field(...)
seed: int | None = Field(default=None)
creativity: str = Field(default="medium")
moodboards: list[KreaMoodboard] | None = Field(default=None)
image_style_references: list[KreaImageStyleReference] | None = Field(default=None)
class KreaJobResult(BaseModel):
urls: list[str] | None = Field(default=None)
style_id: str | None = Field(default=None)
class KreaJob(BaseModel):
job_id: str = Field(...)
status: str = Field(...)
created_at: str = Field(...)
completed_at: str | None = Field(default=None)
result: KreaJobResult | None = Field(default=None)
class KreaAssetResponse(BaseModel):
id: str = Field(...)
image_url: str = Field(...)
uploaded_at: str = Field(...)
width: float | None = Field(default=None)
height: float | None = Field(default=None)
size_bytes: float | None = Field(default=None)
mime_type: str | None = Field(default=None)

View File

@ -1,404 +0,0 @@
from fractions import Fraction
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
from comfy_api_nodes.apis.beeble import (
CreateSwitchXRequest,
SwitchXStatusResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
convert_mask_to_image,
download_url_as_bytesio,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_image_tensor,
downscale_video_to_max_pixels,
poll_op,
sync_op,
upload_image_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_frame_count,
)
_MAX_PIXELS = 2_770_000
_MAX_FRAMES = 240
_MAX_PROMPT_LEN = 2000
def _validate_inputs(prompt: str | None, reference_image: Input.Image | None) -> str | None:
"""Beeble requires at least one of prompt or reference_image. Returns the cleaned prompt."""
cleaned = prompt.strip() if prompt else ""
if not cleaned and reference_image is None:
raise ValueError("At least one of 'prompt' or 'reference_image' must be provided.")
if cleaned:
validate_string(cleaned, strip_whitespace=False, max_length=_MAX_PROMPT_LEN)
return cleaned or None
async def _upload_mask_as_image(
cls: type[IO.ComfyNode],
mask: Input.Image,
*,
wait_label: str,
) -> str:
"""Encode a single-frame MASK (H, W) or (1, H, W) as a PNG and upload."""
if mask.dim() == 2:
mask = mask.unsqueeze(0)
image = convert_mask_to_image(mask[:1])
return await upload_image_to_comfyapi(
cls,
image,
mime_type="image/png",
wait_label=wait_label,
total_pixels=_MAX_PIXELS,
)
async def _upload_mask_batch_as_video(
cls: type[IO.ComfyNode],
mask: Input.Image,
*,
frame_rate: Fraction,
source_frame_count: int,
wait_label: str,
) -> str:
"""Encode a MASK batch (N, H, W) as a grayscale H.264 MP4 at frame_rate and upload.
The matte is always downscaled to the pixel budget so it stays within Beeble's limit and
keeps the same dimensions as the (similarly downscaled) source — both use the same algorithm
from the same starting dimensions, and downscaling is a no-op when already within budget.
"""
if mask.dim() == 2:
mask = mask.unsqueeze(0)
if mask.shape[0] != source_frame_count:
raise ValueError(
f"Custom alpha video frame count ({mask.shape[0]}) does not match the "
f"source video frame count ({source_frame_count}). The Beeble API requires "
"one mask per source frame."
)
images = downscale_image_tensor(convert_mask_to_image(mask), _MAX_PIXELS)
alpha_video = InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=None, frame_rate=frame_rate))
return await upload_video_to_comfyapi(cls, alpha_video, wait_label=wait_label)
def _alpha_mode_input(*, video: bool) -> IO.DynamicCombo.Input:
"""Build the alpha_mode DynamicCombo with mode-specific extra inputs."""
select_keyframe_tooltip = (
"First-frame keyframe mask. Beeble propagates this across the video." if video else "Grayscale keyframe mask."
)
custom_tooltip = (
"Per-frame grayscale mask covering the entire video. "
"Must have the same frame count as the source. "
"Connect a MASK output from SAM3_TrackToMask or similar."
if video
else "Grayscale mask to apply."
)
return IO.DynamicCombo.Input(
"alpha_mode",
tooltip=(
"Controls how SwitchX decides what to keep vs. regenerate. "
"'auto' isolates the main subject automatically. "
"'fill' regenerates the entire frame while preserving geometry. "
"'select' propagates a first-frame keyframe across the clip. "
"'custom' uses a per-frame alpha matte you provide."
),
options=[
IO.DynamicCombo.Option("auto", []),
IO.DynamicCombo.Option("fill", []),
IO.DynamicCombo.Option(
"select",
[IO.Mask.Input("alpha_keyframe", tooltip=select_keyframe_tooltip)],
),
IO.DynamicCombo.Option(
"custom",
[IO.Mask.Input("alpha_mask", tooltip=custom_tooltip)],
),
],
)
def _common_inputs(*, source: IO.Input, video: bool) -> list[IO.Input]:
return [
source,
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip=(
"Text description of the desired output (max 2000 chars). "
"At least one of 'prompt' or 'reference_image' is required."
),
),
IO.Image.Input(
"reference_image",
optional=True,
tooltip=(
"Reference image whose look (background, lighting, costume) the result "
"should adopt. At least one of 'reference_image' or 'prompt' is required."
),
),
_alpha_mode_input(video=video),
IO.Combo.Input(
"max_resolution",
options=["1080p", "720p"],
default="1080p",
tooltip="Maximum output resolution.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip=(
"Seed controls whether the node should re-run; " "results are non-deterministic regardless of seed."
),
),
]
async def _submit_and_poll(
cls: type[IO.ComfyNode],
request: CreateSwitchXRequest,
) -> SwitchXStatusResponse:
initial = await sync_op(
cls,
ApiEndpoint(path="/proxy/beeble/v1/switchx/generations", method="POST"),
response_model=SwitchXStatusResponse,
data=request,
)
return await poll_op(
cls,
ApiEndpoint(path=f"/proxy/beeble/v1/switchx/generations/{initial.id}"),
response_model=SwitchXStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
def _require_output_url(response: SwitchXStatusResponse, name: str) -> str:
if response.output is None or getattr(response.output, name) is None:
raise RuntimeError(f"Beeble job {response.id} completed without a {name!r} output URL.")
return getattr(response.output, name)
def _alpha_url(response: SwitchXStatusResponse, mode: str) -> str | None:
"""URL of the alpha matte, or None when the mode produces no separate matte.
'fill' selects the whole frame, so Beeble writes no alpha asset even though the status
response still returns a (dangling) signed URL for it — fetching it 403s with S3
AccessDenied. The other three modes ('auto', 'custom', 'select') all produce a real,
downloadable matte.
"""
if mode == "fill" or response.output is None:
return None
return response.output.alpha
class BeebleSwitchXVideoEdit(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="BeebleSwitchXVideoEdit",
display_name="Beeble SwitchX Video Edit",
category="api node/video/Beeble",
description=(
"Edit a video with Beeble SwitchX. Switches anything in the scene (background, "
"lighting, costume) while preserving the original subject's pixels and motion. "
"Provide a reference image and/or text prompt to describe the new look. "
"Max 240 frames, max ~2.77MP per frame."
),
inputs=_common_inputs(source=IO.Video.Input("video"), video=True),
outputs=[
IO.Video.Output(display_name="video"),
IO.Video.Output(
display_name="alpha",
tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
expr="""
(
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
{"type":"usd","usd": $rate, "format":{"suffix":"/30 frames"}}
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
prompt: str,
alpha_mode: dict,
max_resolution: str,
seed: int,
reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
cleaned_prompt = _validate_inputs(prompt, reference_image)
validate_video_frame_count(video, max_frame_count=_MAX_FRAMES)
video = downscale_video_to_max_pixels(video, _MAX_PIXELS)
mode = alpha_mode["alpha_mode"]
alpha_uri: str | None = None
if mode == "select":
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")
elif mode == "custom":
alpha_uri = await _upload_mask_batch_as_video(
cls,
alpha_mode["alpha_mask"],
frame_rate=video.get_frame_rate(),
source_frame_count=video.get_frame_count(),
wait_label="Uploading alpha video",
)
source_uri = await upload_video_to_comfyapi(cls, video, wait_label="Uploading source")
reference_uri: str | None = None
if reference_image is not None:
reference_uri = await upload_image_to_comfyapi(
cls,
reference_image,
mime_type="image/png",
wait_label="Uploading reference",
total_pixels=_MAX_PIXELS,
)
request = CreateSwitchXRequest(
generation_type="video",
source_uri=source_uri,
alpha_mode=mode,
prompt=cleaned_prompt,
reference_image_uri=reference_uri,
alpha_uri=alpha_uri,
max_resolution=1080 if max_resolution == "1080p" else 720,
)
response = await _submit_and_poll(cls, request)
render = await download_url_to_video_output(_require_output_url(response, "render"))
alpha = None
if (alpha_url := _alpha_url(response, mode)) is not None:
alpha = await download_url_to_video_output(alpha_url)
return IO.NodeOutput(render, alpha)
class BeebleSwitchXImageEdit(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="BeebleSwitchXImageEdit",
display_name="Beeble SwitchX Image Edit",
category="api node/image/Beeble",
description=(
"Edit a single image with Beeble SwitchX. Switches anything in the scene "
"(background, lighting, costume) while preserving the original subject's pixels. "
"Provide a reference image and/or text prompt to describe the new look. "
"Max ~2.77MP."
),
inputs=_common_inputs(source=IO.Image.Input("image"), video=False),
outputs=[
IO.Image.Output(display_name="image"),
IO.Mask.Output(
display_name="alpha",
tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
expr="""
(
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
{"type":"usd","usd": $rate}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
prompt: str,
alpha_mode: dict,
max_resolution: str,
seed: int,
reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
cleaned_prompt = _validate_inputs(prompt, reference_image)
image = downscale_image_tensor(image, _MAX_PIXELS)
mode = alpha_mode["alpha_mode"]
alpha_uri: str | None = None
if mode == "select":
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")
elif mode == "custom":
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_mask"], wait_label="Uploading alpha")
source_uri = await upload_image_to_comfyapi(
cls,
image,
mime_type="image/png",
wait_label="Uploading source",
total_pixels=None,
)
reference_uri: str | None = None
if reference_image is not None:
reference_uri = await upload_image_to_comfyapi(
cls,
reference_image,
mime_type="image/png",
wait_label="Uploading reference",
total_pixels=_MAX_PIXELS,
)
request = CreateSwitchXRequest(
generation_type="image",
source_uri=source_uri,
alpha_mode=mode,
prompt=cleaned_prompt,
reference_image_uri=reference_uri,
alpha_uri=alpha_uri,
max_resolution=1080 if max_resolution == "1080p" else 720,
)
response = await _submit_and_poll(cls, request)
render = await download_url_to_image_tensor(_require_output_url(response, "render"))
alpha_mask = None
if (alpha_url := _alpha_url(response, mode)) is not None:
alpha_image = bytesio_to_image_tensor(await download_url_as_bytesio(alpha_url), mode="L")
alpha_mask = alpha_image.squeeze(-1) if alpha_image.dim() == 4 else alpha_image
return IO.NodeOutput(render, alpha_mask)
class BeebleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BeebleSwitchXVideoEdit,
BeebleSwitchXImageEdit,
]
async def comfy_entrypoint() -> BeebleExtension:
return BeebleExtension()

View File

@ -2,12 +2,11 @@ import hashlib
import logging
import math
import re
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
@ -309,26 +308,6 @@ async def _seedance_virtual_library_upload_image_asset(
return f"asset://{create_resp.asset_id}"
async def _seedance_virtual_library_upload_video_asset(
cls: type[IO.ComfyNode],
video: Input.Video,
*,
wait_label: str = "Uploading video",
) -> str:
buf = BytesIO()
video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264)
video_hash = hashlib.sha256(buf.getbuffer()).hexdigest()
public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label)
create_resp = await sync_op(
cls,
ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"),
response_model=SeedanceCreateAssetResponse,
data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"),
)
await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library")
return f"asset://{create_resp.asset_id}"
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
@ -2127,7 +2106,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
content.append(
TaskVideoContent(
video_url=TaskVideoContentUrl(
url=await _seedance_virtual_library_upload_video_asset(
url=await upload_video_to_comfyapi(
cls,
reference_videos[key],
wait_label=f"Uploading video {i}",

View File

@ -1,290 +0,0 @@
"""Krea image-generation nodes."""
import re
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.krea import (
KreaAssetResponse,
KreaGenerateImageRequest,
KreaImageStyleReference,
KreaJob,
KreaMoodboard,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
poll_op,
sync_op,
tensor_to_bytesio,
validate_string,
)
class KreaIO:
STYLE_REF = "KREA_STYLE_REF"
async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str:
"""Upload an image to Krea's /assets endpoint and return the Krea-hosted image URL."""
img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png")
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"),
response_model=KreaAssetResponse,
files=[("file", (img_io.name, img_io, "image/png"))],
content_type="multipart/form-data",
max_retries=1,
wait_label="Uploading reference",
)
return response.image_url
_MODEL_MEDIUM = "Krea 2 Medium"
_MODEL_LARGE = "Krea 2 Large"
_MODEL_ENDPOINTS: dict[str, str] = {
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
}
_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"]
_RESOLUTIONS = ["1K"]
_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"]
_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"]
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
def _krea_model_inputs() -> list:
"""Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
return [
IO.Combo.Input(
"aspect_ratio",
options=_ASPECT_RATIOS,
tooltip="Output aspect ratio.",
),
IO.Combo.Input(
"resolution",
options=_RESOLUTIONS,
tooltip="Resolution scale.",
),
IO.Combo.Input(
"creativity",
options=_CREATIVITY_LEVELS,
default="medium",
tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.",
),
IO.String.Input(
"moodboard_id",
default="",
tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). "
"Leave empty to disable. Only one moodboard is supported per request.",
optional=True,
),
IO.Float.Input(
"moodboard_strength",
default=0.35,
min=-0.5,
max=1.5,
step=0.05,
tooltip="Moodboard influence; ignored when moodboard_id is empty.",
optional=True,
),
IO.Custom(KreaIO.STYLE_REF).Input(
"style_reference",
optional=True,
tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.",
),
]
class Krea2ImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Krea2ImageNode",
display_name="Krea 2 Image",
category="api node/image/Krea",
description=(
"Generate images via Krea 2 — pick Medium (expressive illustrations) or "
"Large (expressive photorealism). Supports an optional moodboard and up "
"to 10 chained image style references."
),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the image.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
],
tooltip="Krea 2 Medium is best for expressive illustrations; "
"Krea 2 Large is best for expressive photorealism.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Random seed for reproducibility.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model", "model.moodboard_id"],
inputs=["model.style_reference"],
),
expr="""
(
$isLarge := widgets.model = "krea 2 large";
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
$hasStyle := $lookup(inputs, "model.style_reference").connected;
$usd := $hasMoodboard
? ($isLarge ? 0.07 : 0.04)
: ($hasStyle
? ($isLarge ? 0.065 : 0.035)
: ($isLarge ? 0.06 : 0.03));
{"type":"usd","usd": $usd}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1)
model_choice = model["model"]
endpoint_path = _MODEL_ENDPOINTS.get(model_choice)
if endpoint_path is None:
raise ValueError(f"Unknown Krea 2 model: {model_choice!r}")
moodboards: list[KreaMoodboard] | None = None
mb_id = (model.get("moodboard_id") or "").strip()
if mb_id:
if not _UUID_RE.match(mb_id):
raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.")
mb_strength = model.get("moodboard_strength")
moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))]
style_reference = model.get("style_reference")
image_style_references: list[KreaImageStyleReference] | None = None
if style_reference:
if len(style_reference) > 10:
raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.")
image_style_references = [
KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference
]
initial = await sync_op(
cls,
ApiEndpoint(path=endpoint_path, method="POST"),
response_model=KreaJob,
data=KreaGenerateImageRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
resolution=model["resolution"],
seed=seed,
creativity=model["creativity"],
moodboards=moodboards,
image_style_references=image_style_references,
),
)
job = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"),
response_model=KreaJob,
status_extractor=lambda r: r.status,
queued_statuses=_KREA_QUEUED_STATUSES,
)
if not job.result or not job.result.urls:
raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.")
image = await download_url_to_image_tensor(job.result.urls[0])
return IO.NodeOutput(image)
class Krea2StyleReferenceNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Krea2StyleReferenceNode",
display_name="Krea 2 Style Reference",
category="api node/image/Krea",
description=(
"Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
"Style Reference nodes (max 10) and feed the final `style_reference` output "
"into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL."
),
inputs=[
IO.Image.Input(
"image",
tooltip="Reference image whose style influences the generation.",
),
IO.Float.Input(
"strength",
default=1.0,
min=-2.0,
max=2.0,
step=0.05,
tooltip="Reference strength; negative values invert the style influence.",
),
IO.Custom(KreaIO.STYLE_REF).Input(
"style_reference",
optional=True,
tooltip="Optional incoming chain of style references; this node appends one more.",
),
],
outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@classmethod
async def execute(
cls,
image: Input.Image,
strength: float,
style_reference: list[dict] | None = None,
) -> IO.NodeOutput:
chain: list[dict] = list(style_reference) if style_reference else []
if len(chain) >= 10:
raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.")
url = await _upload_image_to_krea_assets(cls, image)
chain.append({"url": url, "strength": float(strength)})
return IO.NodeOutput(chain)
class KreaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Krea2ImageNode,
Krea2StyleReferenceNode,
]
async def comfy_entrypoint() -> KreaExtension:
return KreaExtension()

View File

@ -86,7 +86,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
async def sync_op(

View File

@ -47,7 +47,6 @@ class Load3D(IO.ComfyNode):
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Video.Output(display_name="recording_video"),
IO.File3DAny.Output(display_name="model_3d"),
IO.Load3DModelInfo.Output(display_name="model_info"),
],
)
@ -70,8 +69,7 @@ class Load3D(IO.ComfyNode):
video = InputImpl.VideoFromFile(recording_video_path)
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
model_info = image.get('model_info', [])
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d, model_info)
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
process = execute # TODO: remove

View File

@ -226,20 +226,10 @@ def get_noise_mask(latent):
noise_mask = noise_mask.clone()
return noise_mask
def get_keyframe_idxs(cond, latent_shape=None):
def get_keyframe_idxs(cond):
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
if keyframe_idxs is None:
return None, 0
# Get number of keyframes from latent_shape or guide_attention_entries if available
if latent_shape is not None and len(latent_shape) == 5:
tokens_per_frame = latent_shape[-2] * latent_shape[-1]
num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame
return keyframe_idxs, num_keyframes
entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
if entries:
num_keyframes = sum(e["latent_shape"][0] for e in entries)
return keyframe_idxs, num_keyframes
# fallback, may under-count if keyframes share t-start
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
return keyframe_idxs, num_keyframes
@ -332,9 +322,9 @@ class LTXVAddGuide(io.ComfyNode):
return factor
@classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond, latent_shape)
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1 and frame_idx != 0:
@ -446,7 +436,7 @@ class LTXVAddGuide(io.ComfyNode):
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
resolved_frame_idx = frame_idx
if frame_idx < 0:
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
_, num_keyframes = get_keyframe_idxs(positive)
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
@ -464,7 +454,7 @@ class LTXVAddGuide(io.ComfyNode):
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
positive, negative, latent_image, noise_mask = cls.append_keyframe(
@ -516,7 +506,7 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
_, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0:
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)

View File

@ -1,32 +1,32 @@
from comfy import model_management
from comfy_api.latest import ComfyExtension, IO
from typing_extensions import override
import math
class LTXVLatentUpsampler(IO.ComfyNode):
class LTXVLatentUpsampler:
"""
Upsamples a video latent by a factor of 2.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LTXVLatentUpsampler",
category="latent/video",
is_experimental=True,
inputs=[
IO.Latent.Input("samples"),
IO.LatentUpscaleModel.Input("upscale_model"),
IO.Vae.Input("vae"),
],
outputs=[
IO.Latent.Output(),
],
)
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
"upscale_model": ("LATENT_UPSCALE_MODEL",),
"vae": ("VAE",),
}
}
@classmethod
def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput:
RETURN_TYPES = ("LATENT",)
FUNCTION = "upsample_latent"
CATEGORY = "latent/video"
EXPERIMENTAL = True
def upsample_latent(
self,
samples: dict,
upscale_model,
vae,
) -> tuple:
"""
Upsample the input latent using the provided model.
@ -34,6 +34,7 @@ class LTXVLatentUpsampler(IO.ComfyNode):
samples (dict): Input latent samples
upscale_model (LatentUpsampler): Loaded upscale model
vae: VAE model for normalization
auto_tiling (bool): Whether to automatically tile the input for processing
Returns:
tuple: Tuple containing the upsampled latent
@ -66,16 +67,9 @@ class LTXVLatentUpsampler(IO.ComfyNode):
return_dict = samples.copy()
return_dict["samples"] = upsampled_latents
return_dict.pop("noise_mask", None)
return IO.NodeOutput(return_dict)
upsample_latent = execute # TODO: remove
return (return_dict,)
class LTXVLatentUpsamplerExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [LTXVLatentUpsampler]
async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension:
return LTXVLatentUpsamplerExtension()
NODE_CLASS_MAPPINGS = {
"LTXVLatentUpsampler": LTXVLatentUpsampler,
}

View File

@ -1,55 +0,0 @@
"""PiD (Pixel Diffusion Decoder) node"""
import torch
from typing_extensions import override
import node_helpers
import comfy.latent_formats
from comfy_api.latest import ComfyExtension, io
class PiDConditioning(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="PiDConditioning",
display_name="PiD Conditioning",
category="advanced/conditioning",
description=(
"Attaches a latent and a degrade_sigma scalar to a CONDITIONING for PiD decoding/upscaling"
),
inputs=[
io.Conditioning.Input("positive"),
io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."),
io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux",
tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."),
io.Float.Input(
"degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,
tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.",
),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput:
samples = latent["samples"]
if latent_format == "flux":
fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux
else:
fmt_cls = comfy.latent_formats.SD3
lq_latent = fmt_cls().process_in(samples)
sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32)
return io.NodeOutput(node_helpers.conditioning_set_values(
positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t},
))
class PiDExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [PiDConditioning]
async def comfy_entrypoint() -> PiDExtension:
return PiDExtension()

View File

@ -969,7 +969,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -979,7 +979,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b\n pixeldit: gemma 2 2B elm"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@ -2420,7 +2420,6 @@ async def init_builtin_extra_nodes():
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_chroma_radiance.py",
"nodes_pid.py",
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.44.19
comfyui-workflow-templates==0.9.85
comfyui-workflow-templates==0.9.82
comfyui-embedded-docs==0.5.1
torch
torchsde