mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-06 10:17:59 +08:00
Compare commits
22 Commits
dev/Combo-
...
causal_for
| Author | SHA1 | Date | |
|---|---|---|---|
| efee285099 | |||
| fc303cb2cf | |||
| 8ad8d101a1 | |||
| 087844bd50 | |||
| cac5120f96 | |||
| 38514045ab | |||
| f274674f14 | |||
| a4377d588e | |||
| df7398a692 | |||
| d1998480b0 | |||
| 92f937daa6 | |||
| 2aacf304fa | |||
| bbce5a8c75 | |||
| cbd2b17c67 | |||
| c33d26c283 | |||
| f3ea976cba | |||
| 5538f62b0b | |||
| 2806163f6e | |||
| cea8d0925f | |||
| b138133ffa | |||
| 025e6792ee | |||
| 867b8d2408 |
@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
|
||||
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
|
||||
@ -1810,3 +1810,119 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
num_frame_per_block=1):
|
||||
"""
|
||||
Autoregressive video sampler: block-by-block denoising with KV cache
|
||||
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
||||
|
||||
Requires a Causal-WAN compatible model (diffusion_model must expose
|
||||
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
|
||||
|
||||
All AR-loop parameters are passed via the SamplerARVideo node, not read
|
||||
from the checkpoint or transformer_options.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
model_options = extra_args.get("model_options", {})
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
|
||||
if x.ndim != 5:
|
||||
raise ValueError(
|
||||
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
|
||||
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
|
||||
)
|
||||
|
||||
inner_model = model.inner_model.inner_model
|
||||
causal_model = inner_model.diffusion_model
|
||||
|
||||
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
|
||||
raise TypeError(
|
||||
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
|
||||
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
|
||||
"does not support this interface — choose a different sampler."
|
||||
)
|
||||
|
||||
seed = extra_args.get("seed", 0)
|
||||
|
||||
bs, c, lat_t, lat_h, lat_w = x.shape
|
||||
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
|
||||
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
|
||||
device = x.device
|
||||
model_dtype = inner_model.get_dtype()
|
||||
|
||||
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
|
||||
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
|
||||
|
||||
output = torch.zeros_like(x)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
current_start_frame = 0
|
||||
|
||||
# I2V: seed KV cache with the initial image latent before the denoising loop
|
||||
initial_latent = ar_config.get("initial_latent", None)
|
||||
if initial_latent is not None:
|
||||
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
|
||||
n_init = initial_latent.shape[2]
|
||||
output[:, :, :n_init] = initial_latent
|
||||
|
||||
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
|
||||
transformer_options["ar_state"] = ar_state
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
|
||||
|
||||
current_start_frame = n_init
|
||||
remaining = lat_t - n_init
|
||||
num_blocks = -(-remaining // num_frame_per_block)
|
||||
|
||||
num_sigma_steps = len(sigmas) - 1
|
||||
total_real_steps = num_blocks * num_sigma_steps
|
||||
step_count = 0
|
||||
|
||||
try:
|
||||
for block_idx in trange(num_blocks, disable=disable):
|
||||
bf = min(num_frame_per_block, lat_t - current_start_frame)
|
||||
fs, fe = current_start_frame, current_start_frame + bf
|
||||
noisy_input = x[:, :, fs:fe]
|
||||
|
||||
ar_state = {
|
||||
"start_frame": current_start_frame,
|
||||
"kv_caches": kv_caches,
|
||||
"crossattn_caches": crossattn_caches,
|
||||
}
|
||||
transformer_options["ar_state"] = ar_state
|
||||
|
||||
for i in range(num_sigma_steps):
|
||||
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
|
||||
|
||||
if callback is not None:
|
||||
scaled_i = step_count * num_sigma_steps // total_real_steps
|
||||
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
|
||||
"sigma_hat": sigmas[i], "denoised": denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
noisy_input = denoised
|
||||
else:
|
||||
sigma_next = sigmas[i + 1]
|
||||
torch.manual_seed(seed + block_idx * 1000 + i)
|
||||
fresh_noise = torch.randn_like(denoised)
|
||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
|
||||
step_count += 1
|
||||
|
||||
output[:, :, fs:fe] = noisy_input
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
||||
|
||||
current_start_frame += bf
|
||||
finally:
|
||||
transformer_options.pop("ar_state", None)
|
||||
|
||||
return output
|
||||
|
||||
276
comfy/ldm/wan/ar_model.py
Normal file
276
comfy/ldm/wan/ar_model.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""
|
||||
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
|
||||
autoregressive (frame-by-frame) video generation via Causal Forcing.
|
||||
|
||||
Weight-compatible with the standard WanModel -- same layer names, same shapes.
|
||||
The difference is purely in the forward pass: this model processes one temporal
|
||||
block at a time and maintains a KV cache across blocks.
|
||||
|
||||
Reference: https://github.com/thu-ml/Causal-Forcing
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.wan.model import (
|
||||
sinusoidal_embedding_1d,
|
||||
repeat_e,
|
||||
WanModel,
|
||||
WanAttentionBlock,
|
||||
)
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class CausalWanSelfAttention(nn.Module):
|
||||
"""Self-attention with KV cache support for autoregressive inference."""
|
||||
|
||||
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
ops = operation_settings.get("operations")
|
||||
device = operation_settings.get("device")
|
||||
dtype = operation_settings.get("dtype")
|
||||
|
||||
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
|
||||
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
|
||||
v = self.v(x).view(b, s, n, d)
|
||||
|
||||
if kv_cache is None:
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
v.view(b, s, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
end = kv_cache["end"]
|
||||
new_end = end + s
|
||||
|
||||
# Roped K and plain V go into cache
|
||||
kv_cache["k"][:, end:new_end] = k
|
||||
kv_cache["v"][:, end:new_end] = v
|
||||
kv_cache["end"] = new_end
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
|
||||
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanAttentionBlock(WanAttentionBlock):
|
||||
"""Transformer block with KV-cached self-attention and cross-attention caching."""
|
||||
|
||||
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
self.self_attn = CausalWanSelfAttention(
|
||||
dim, num_heads, window_size, qk_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, e, freqs, context, context_img_len=257,
|
||||
kv_cache=None, crossattn_cache=None, transformer_options={}):
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||
|
||||
# Self-attention with optional KV cache
|
||||
x = x.contiguous()
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
del y
|
||||
|
||||
# Cross-attention with optional caching
|
||||
if crossattn_cache is not None and crossattn_cache.get("is_init"):
|
||||
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
|
||||
x_ca = optimized_attention(
|
||||
q, crossattn_cache["k"], crossattn_cache["v"],
|
||||
heads=self.num_heads, transformer_options=transformer_options)
|
||||
x = x + self.cross_attn.o(x_ca)
|
||||
else:
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
if crossattn_cache is not None:
|
||||
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
|
||||
crossattn_cache["v"] = self.cross_attn.v(context)
|
||||
crossattn_cache["is_init"] = True
|
||||
|
||||
# FFN
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanModel(WanModel):
|
||||
"""
|
||||
Wan 2.1 diffusion backbone with causal KV-cache support.
|
||||
|
||||
Same weight structure as WanModel -- loads identical state dicts.
|
||||
Adds forward_block() for frame-by-frame autoregressive inference.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None):
|
||||
super().__init__(
|
||||
model_type=model_type, patch_size=patch_size, text_len=text_len,
|
||||
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
|
||||
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
|
||||
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
|
||||
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
|
||||
wan_attn_block_class=CausalWanAttentionBlock,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward_block(self, x, timestep, context, start_frame,
|
||||
kv_caches, crossattn_caches, clip_fea=None):
|
||||
"""
|
||||
Forward one temporal block for autoregressive inference.
|
||||
|
||||
Args:
|
||||
x: [B, C, block_frames, H, W] input latent for the current block
|
||||
timestep: [B, block_frames] per-frame timesteps
|
||||
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
|
||||
start_frame: temporal frame index for RoPE offset
|
||||
kv_caches: list of per-layer KV cache dicts
|
||||
crossattn_caches: list of per-layer cross-attention cache dicts
|
||||
clip_fea: optional CLIP features for I2V
|
||||
|
||||
Returns:
|
||||
flow_pred: [B, C_out, block_frames, H, W] flow prediction
|
||||
"""
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
bs, c, t, h, w = x.shape
|
||||
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# Per-frame time embedding
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# Text embedding (reuses crossattn_cache after first block)
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
# RoPE for current block's temporal position
|
||||
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
|
||||
|
||||
# Transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
x = block(x, e=e0, freqs=freqs, context=context,
|
||||
context_img_len=context_img_len,
|
||||
kv_cache=kv_caches[i],
|
||||
crossattn_cache=crossattn_caches[i])
|
||||
|
||||
# Head
|
||||
x = self.head(x, e)
|
||||
|
||||
# Unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x[:, :, :t, :h, :w]
|
||||
|
||||
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
|
||||
"""Create fresh KV caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({
|
||||
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"end": 0,
|
||||
})
|
||||
return caches
|
||||
|
||||
def init_crossattn_caches(self, batch_size, device, dtype):
|
||||
"""Create fresh cross-attention caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({"is_init": False})
|
||||
return caches
|
||||
|
||||
def reset_kv_caches(self, kv_caches):
|
||||
"""Reset KV caches to empty (reuse allocated memory)."""
|
||||
for cache in kv_caches:
|
||||
cache["end"] = 0
|
||||
|
||||
def reset_crossattn_caches(self, crossattn_caches):
|
||||
"""Reset cross-attention caches."""
|
||||
for cache in crossattn_caches:
|
||||
cache["is_init"] = False
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.dim // self.num_heads
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
ar_state = transformer_options.get("ar_state")
|
||||
if ar_state is not None:
|
||||
bs = x.shape[0]
|
||||
block_frames = x.shape[2]
|
||||
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
|
||||
return self.forward_block(
|
||||
x=x, timestep=t_per_frame, context=context,
|
||||
start_frame=ar_state["start_frame"],
|
||||
kv_caches=ar_state["kv_caches"],
|
||||
crossattn_caches=ar_state["crossattn_caches"],
|
||||
clip_fea=clip_fea,
|
||||
)
|
||||
|
||||
return super().forward(x, timestep, context, clip_fea=clip_fea,
|
||||
time_dim_concat=time_dim_concat,
|
||||
transformer_options=transformer_options, **kwargs)
|
||||
@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
@ -1365,6 +1366,13 @@ class WAN21(BaseModel):
|
||||
return out
|
||||
|
||||
|
||||
class WAN21_CausalAR(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
|
||||
self.image_to_video = False
|
||||
|
||||
|
||||
class WAN21_Vace(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
@ -21,7 +23,15 @@ try:
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
if args.enable_triton_backend:
|
||||
try:
|
||||
import triton
|
||||
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
|
||||
ck.registry.disable("triton")
|
||||
else:
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
|
||||
@ -1167,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
||||
|
||||
class WAN21_CausalAR_T2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "t2v",
|
||||
"causal_ar": True,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 5.0,
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.unet_config.pop("causal_ar", None)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.WAN21_CausalAR(self, device=device)
|
||||
|
||||
|
||||
class WAN21_I2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1929,6 +1948,7 @@ models = [
|
||||
ZImage,
|
||||
Lumina2,
|
||||
WAN22_T2V,
|
||||
WAN21_CausalAR_T2V,
|
||||
WAN21_T2V,
|
||||
WAN21_I2V,
|
||||
WAN21_FunControl2V,
|
||||
|
||||
@ -43,67 +43,7 @@ class UploadType(str, Enum):
|
||||
model = "file_upload"
|
||||
|
||||
|
||||
class RemoteItemSchema:
|
||||
"""Describes how to map API response objects to rich dropdown items.
|
||||
|
||||
All *_field parameters use dot-path notation (e.g. ``"labels.gender"``).
|
||||
``label_field`` and ``description_field`` additionally support template strings
|
||||
with ``{field}`` placeholders (e.g. ``"{name} ({labels.accent})"``).
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
value_field: str,
|
||||
label_field: str,
|
||||
preview_url_field: str | None = None,
|
||||
preview_type: Literal["image", "video", "audio"] = "image",
|
||||
description_field: str | None = None,
|
||||
search_fields: list[str] | None = None,
|
||||
):
|
||||
if preview_type not in ("image", "video", "audio"):
|
||||
raise ValueError(
|
||||
f"RemoteItemSchema: 'preview_type' must be 'image', 'video', or 'audio'; got {preview_type!r}."
|
||||
)
|
||||
if search_fields is not None:
|
||||
for f in search_fields:
|
||||
if "{" in f or "}" in f:
|
||||
raise ValueError(
|
||||
f"RemoteItemSchema: 'search_fields' must be dot-paths, not template strings (got {f!r})."
|
||||
)
|
||||
self.value_field = value_field
|
||||
"""Dot-path to the unique identifier within each item.
|
||||
This value is stored in the widget and passed to execute()."""
|
||||
self.label_field = label_field
|
||||
"""Dot-path to the display name, or a template string with {field} placeholders."""
|
||||
self.preview_url_field = preview_url_field
|
||||
"""Dot-path to a preview media URL. If None, no preview is shown."""
|
||||
self.preview_type = preview_type
|
||||
"""How to render the preview: "image", "video", or "audio"."""
|
||||
self.description_field = description_field
|
||||
"""Optional dot-path or template for a subtitle line shown below the label."""
|
||||
self.search_fields = search_fields
|
||||
"""Dot-paths to fields included in the search index. When unset, search falls back to
|
||||
the resolved label (i.e. ``label_field`` after template substitution). Note that template
|
||||
label strings (e.g. ``"{first} {last}"``) are not valid path entries here — list the
|
||||
underlying paths (``["first", "last"]``) instead."""
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
"value_field": self.value_field,
|
||||
"label_field": self.label_field,
|
||||
"preview_url_field": self.preview_url_field,
|
||||
"preview_type": self.preview_type,
|
||||
"description_field": self.description_field,
|
||||
"search_fields": self.search_fields,
|
||||
})
|
||||
|
||||
|
||||
class RemoteOptions:
|
||||
"""Plain remote combo: fetches a list of strings/objects and populates a standard dropdown.
|
||||
|
||||
Use this for lightweight lists from endpoints that return a bare array (or an array under
|
||||
``response_key``). For rich dropdowns with previews, search, filtering, or pagination,
|
||||
use :class:`RemoteComboOptions` and the ``remote_combo=`` parameter on ``Combo.Input``.
|
||||
"""
|
||||
def __init__(self, route: str, refresh_button: bool, control_after_refresh: Literal["first", "last"]="first",
|
||||
timeout: int=None, max_retries: int=None, refresh: int=None):
|
||||
self.route = route
|
||||
@ -130,80 +70,6 @@ class RemoteOptions:
|
||||
})
|
||||
|
||||
|
||||
class RemoteComboOptions:
|
||||
"""Rich remote combo: populates a Vue dropdown with previews, search, and filtering.
|
||||
|
||||
Attached to a :class:`Combo.Input` via ``remote_combo=`` (not ``remote=``). Requires an
|
||||
``item_schema`` describing how to map API response objects to dropdown items.
|
||||
|
||||
Response-shape contract: the endpoint returns the full items array in a single response
|
||||
(either at the top level, or at the dot-path given by ``response_key``). Backing endpoints
|
||||
that paginate upstream are expected to aggregate and cache server-side.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
route: str,
|
||||
item_schema: RemoteItemSchema,
|
||||
refresh_button: bool = True,
|
||||
auto_select: Literal["first", "last"] | None = None,
|
||||
timeout: int | None = None,
|
||||
max_retries: int | None = None,
|
||||
refresh: int | None = None,
|
||||
response_key: str | None = None,
|
||||
):
|
||||
if auto_select is not None and auto_select not in ("first", "last"):
|
||||
raise ValueError(
|
||||
f"RemoteComboOptions: 'auto_select' must be 'first', 'last', or None; got {auto_select!r}."
|
||||
)
|
||||
if refresh is not None and 0 < refresh < 128:
|
||||
raise ValueError(
|
||||
f"RemoteComboOptions: 'refresh' must be >= 128 (ms TTL) or <= 0 (cache never expires); got {refresh}."
|
||||
)
|
||||
if timeout is not None and timeout < 0:
|
||||
raise ValueError(
|
||||
f"RemoteComboOptions: 'timeout' must be >= 0 (got {timeout})."
|
||||
)
|
||||
if max_retries is not None and max_retries < 0:
|
||||
raise ValueError(
|
||||
f"RemoteComboOptions: 'max_retries' must be >= 0 (got {max_retries})."
|
||||
)
|
||||
if not route.startswith("/"):
|
||||
raise ValueError(
|
||||
f"RemoteComboOptions: 'route' must be a relative path starting with '/'; got {route!r}."
|
||||
)
|
||||
self.route = route
|
||||
"""Relative path to the remote source (must start with ``/``). The frontend resolves this
|
||||
against the comfy-api base URL and injects auth headers; absolute URLs are rejected."""
|
||||
self.item_schema = item_schema
|
||||
"""Required: describes how each API response object maps to a dropdown item."""
|
||||
self.refresh_button = refresh_button
|
||||
"""Specifies whether to show a refresh button next to the widget."""
|
||||
self.auto_select = auto_select
|
||||
"""Fallback item to select when the widget's value is empty. Never overrides an existing
|
||||
selection. Default None means no fallback."""
|
||||
self.timeout = timeout
|
||||
"""Maximum time to wait for a response, in milliseconds."""
|
||||
self.max_retries = max_retries
|
||||
"""Maximum number of retries before aborting the request. Default None uses the frontend's built-in limit."""
|
||||
self.refresh = refresh
|
||||
"""TTL of the cached value in milliseconds. Must be >= 128 (ms TTL) or <= 0 (cache never expires,
|
||||
re-fetched only via the refresh button). Default None uses the frontend's built-in behavior."""
|
||||
self.response_key = response_key
|
||||
"""Dot-path to the items array within the response (when not at the top level)."""
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
"route": self.route,
|
||||
"item_schema": self.item_schema.as_dict(),
|
||||
"refresh_button": self.refresh_button,
|
||||
"auto_select": self.auto_select,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
"refresh": self.refresh,
|
||||
"response_key": self.response_key,
|
||||
})
|
||||
|
||||
|
||||
class NumberDisplay(str, Enum):
|
||||
number = "number"
|
||||
slider = "slider"
|
||||
@ -493,16 +359,11 @@ class Combo(ComfyTypeIO):
|
||||
upload: UploadType=None,
|
||||
image_folder: FolderType=None,
|
||||
remote: RemoteOptions=None,
|
||||
remote_combo: RemoteComboOptions=None,
|
||||
socketless: bool=None,
|
||||
extra_dict=None,
|
||||
raw_link: bool=None,
|
||||
advanced: bool=None,
|
||||
):
|
||||
if remote is not None and remote_combo is not None:
|
||||
raise ValueError("Combo.Input: pass either 'remote' or 'remote_combo', not both.")
|
||||
if options is not None and remote_combo is not None:
|
||||
raise ValueError("Combo.Input: pass either 'options' or 'remote_combo', not both.")
|
||||
if isinstance(options, type) and issubclass(options, Enum):
|
||||
options = [v.value for v in options]
|
||||
if isinstance(default, Enum):
|
||||
@ -514,7 +375,6 @@ class Combo(ComfyTypeIO):
|
||||
self.upload = upload
|
||||
self.image_folder = image_folder
|
||||
self.remote = remote
|
||||
self.remote_combo = remote_combo
|
||||
self.default: str
|
||||
|
||||
def as_dict(self):
|
||||
@ -525,7 +385,6 @@ class Combo(ComfyTypeIO):
|
||||
**({self.upload.value: True} if self.upload is not None else {}),
|
||||
"image_folder": self.image_folder.value if self.image_folder else None,
|
||||
"remote": self.remote.as_dict() if self.remote else None,
|
||||
"remote_combo": self.remote_combo.as_dict() if self.remote_combo else None,
|
||||
})
|
||||
|
||||
class Output(Output):
|
||||
@ -2362,9 +2221,7 @@ class NodeReplace:
|
||||
__all__ = [
|
||||
"FolderType",
|
||||
"UploadType",
|
||||
"RemoteItemSchema",
|
||||
"RemoteOptions",
|
||||
"RemoteComboOptions",
|
||||
"NumberDisplay",
|
||||
"ControlAfterGenerate",
|
||||
|
||||
|
||||
@ -199,6 +199,9 @@ class FILMNet(nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
|
||||
|
||||
def memory_used_forward(self, shape, dtype):
|
||||
return 1700 * shape[1] * shape[2] * dtype.itemsize
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
"""Pre-compute warp grids for all pyramid levels."""
|
||||
if (H, W) in self._warp_grids:
|
||||
|
||||
@ -74,6 +74,9 @@ class IFNet(nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.encode.cnn0.weight.dtype
|
||||
|
||||
def memory_used_forward(self, shape, dtype):
|
||||
return 300 * shape[1] * shape[2] * dtype.itemsize
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
if (H, W) in self._warp_grids:
|
||||
return
|
||||
|
||||
136
comfy_extras/nodes_ar_video.py
Normal file
136
comfy_extras/nodes_ar_video.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""
|
||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
||||
- ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class EmptyARVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyARVideoLatent",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="LATENT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
|
||||
lat_t = ((length - 1) // 4) + 1
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
|
||||
class SamplerARVideo(io.ComfyNode):
|
||||
"""Sampler for autoregressive video models (Causal Forcing, Self-Forcing).
|
||||
|
||||
All AR-loop parameters are owned by this node so they live in the workflow.
|
||||
Add new widgets here as the AR sampler grows new options.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerARVideo",
|
||||
display_name="Sampler AR Video",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Int.Input(
|
||||
"num_frame_per_block",
|
||||
default=1, min=1, max=64,
|
||||
tooltip="Frames per autoregressive block. 1 = framewise, "
|
||||
"3 = chunkwise. Must match the checkpoint's training mode.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, num_frame_per_block) -> io.NodeOutput:
|
||||
extra_options = {
|
||||
"num_frame_per_block": num_frame_per_block,
|
||||
}
|
||||
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
||||
|
||||
|
||||
class ARVideoI2V(io.ComfyNode):
|
||||
"""Image-to-video setup for AR video models (Causal Forcing, Self-Forcing).
|
||||
|
||||
VAE-encodes the start image and stores it in the model's transformer_options
|
||||
so that sample_ar_video can seed the KV cache before denoising.
|
||||
Uses the same T2V model checkpoint -- no separate I2V architecture needed.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ARVideoI2V",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("start_image"),
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="MODEL"),
|
||||
io.Latent.Output(display_name="LATENT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput:
|
||||
start_image = comfy.utils.common_upscale(
|
||||
start_image[:1].movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
|
||||
initial_latent = vae.encode(start_image[:, :, :, :3])
|
||||
|
||||
m = model.clone()
|
||||
to = m.model_options.setdefault("transformer_options", {})
|
||||
ar_cfg = to.setdefault("ar_config", {})
|
||||
ar_cfg["initial_latent"] = initial_latent
|
||||
|
||||
lat_t = ((length - 1) // 4) + 1
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
return io.NodeOutput(m, {"samples": latent})
|
||||
|
||||
|
||||
class ARVideoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyARVideoLatent,
|
||||
SamplerARVideo,
|
||||
ARVideoI2V,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ARVideoExtension:
|
||||
return ARVideoExtension()
|
||||
@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||
batch_size = min(len(image), len(alpha))
|
||||
out_images = []
|
||||
|
||||
batch_size = max(len(image), len(alpha))
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
return io.NodeOutput(torch.stack(out_images))
|
||||
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
|
||||
image = comfy.utils.repeat_to_batch_size(image, batch_size)
|
||||
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
|
||||
|
||||
|
||||
class CompositingExtension(ComfyExtension):
|
||||
|
||||
@ -37,7 +37,7 @@ class FrameInterpolationModelLoader(io.ComfyNode):
|
||||
model = cls._detect_and_load(sd)
|
||||
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
|
||||
model.eval().to(dtype)
|
||||
patcher = comfy.model_patcher.ModelPatcher(
|
||||
patcher = comfy.model_patcher.CoreModelPatcher(
|
||||
model,
|
||||
load_device=model_management.get_torch_device(),
|
||||
offload_device=model_management.unet_offload_device(),
|
||||
@ -98,16 +98,13 @@ class FrameInterpolate(io.ComfyNode):
|
||||
if num_frames < 2 or multiplier < 2:
|
||||
return io.NodeOutput(images)
|
||||
|
||||
model_management.load_model_gpu(interp_model)
|
||||
device = interp_model.load_device
|
||||
dtype = interp_model.model_dtype()
|
||||
inference_model = interp_model.model
|
||||
|
||||
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
activation_mem = H * W * 3 * images.element_size() * 20
|
||||
model_management.free_memory(activation_mem, device)
|
||||
activation_mem = inference_model.memory_used_forward(images.shape, dtype)
|
||||
model_management.load_models_gpu([interp_model], memory_required=activation_mem)
|
||||
align = getattr(inference_model, "pad_align", 1)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
|
||||
# Prepare a single padded frame on device for determining output dimensions
|
||||
def prepare_frame(idx):
|
||||
|
||||
@ -666,12 +666,13 @@ class ColorTransfer(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ColorTransfer",
|
||||
display_name="Color Transfer",
|
||||
category="image/postprocessing",
|
||||
description="Match the colors of one image to another using various algorithms.",
|
||||
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
|
||||
inputs=[
|
||||
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
|
||||
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
|
||||
io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."),
|
||||
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
|
||||
io.DynamicCombo.Input("source_stats",
|
||||
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
|
||||
|
||||
@ -49,7 +49,7 @@ class Int(io.ComfyNode):
|
||||
display_name="Int",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
@ -1016,10 +1016,6 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
|
||||
if isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||
if input_type == io.Combo.io_type:
|
||||
# Skip validation for combos with remote options — options
|
||||
# are fetched client-side and not available on the server.
|
||||
if extra_info.get("remote_combo"):
|
||||
continue
|
||||
combo_options = extra_info.get("options", [])
|
||||
else:
|
||||
combo_options = input_type
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
#config for a1111 ui
|
||||
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
|
||||
|
||||
#a111:
|
||||
#a1111:
|
||||
# base_path: path/to/stable-diffusion-webui/
|
||||
# checkpoints: models/Stable-diffusion
|
||||
# configs: models/Stable-diffusion
|
||||
|
||||
67
nodes.py
67
nodes.py
@ -1754,57 +1754,49 @@ class LoadImage:
|
||||
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
|
||||
class LoadImageMask(LoadImage):
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(files), {"image_upload": True}),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
types = super().INPUT_TYPES()
|
||||
return {
|
||||
"required": {
|
||||
**types["required"],
|
||||
"channel": (s._color_channels, )
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = node_helpers.pillow(Image.open, image_path)
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
i = i.convert("RGBA")
|
||||
mask = None
|
||||
FUNCTION = "load_image_mask"
|
||||
|
||||
def load_image_mask(self, image, channel):
|
||||
image_tensor, mask_tensor = super().load_image(image)
|
||||
c = channel[0].upper()
|
||||
if c in i.getbands():
|
||||
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(mask)
|
||||
if c == 'A':
|
||||
mask = 1. - mask
|
||||
|
||||
if c == 'A':
|
||||
return (mask_tensor,)
|
||||
|
||||
channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
|
||||
|
||||
if channel_idx < image_tensor.shape[-1]:
|
||||
return (image_tensor[..., channel_idx].clone(),)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
return (mask.unsqueeze(0),)
|
||||
empty_mask = torch.zeros(
|
||||
image_tensor.shape[:-1],
|
||||
dtype=image_tensor.dtype,
|
||||
device=image_tensor.device
|
||||
)
|
||||
return (empty_mask,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
return True
|
||||
return super().IS_CHANGED(image)
|
||||
|
||||
|
||||
class LoadImageOutput(LoadImage):
|
||||
@ -2419,6 +2411,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nop.py",
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
"nodes_ar_video.py",
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
"nodes_glsl.py",
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
@ -1245,7 +1246,13 @@ class PromptServer():
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
try:
|
||||
await site.start()
|
||||
except OSError as e:
|
||||
if e.errno == errno.EADDRINUSE:
|
||||
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
|
||||
raise SystemExit(1)
|
||||
raise
|
||||
|
||||
if not hasattr(self, 'address'):
|
||||
self.address = address #TODO: remove this
|
||||
|
||||
@ -1,139 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from comfy_api.latest._io import (
|
||||
Combo,
|
||||
RemoteComboOptions,
|
||||
RemoteItemSchema,
|
||||
RemoteOptions,
|
||||
)
|
||||
|
||||
|
||||
def _schema(**overrides):
|
||||
defaults = dict(value_field="id", label_field="name")
|
||||
return RemoteItemSchema(**{**defaults, **overrides})
|
||||
|
||||
|
||||
def _combo(**overrides):
|
||||
defaults = dict(route="/proxy/foo", item_schema=_schema())
|
||||
return RemoteComboOptions(**{**defaults, **overrides})
|
||||
|
||||
|
||||
def test_item_schema_defaults_accepted():
|
||||
d = _schema().as_dict()
|
||||
assert d == {"value_field": "id", "label_field": "name", "preview_type": "image"}
|
||||
|
||||
|
||||
def test_item_schema_full_config_accepted():
|
||||
d = _schema(
|
||||
preview_url_field="preview",
|
||||
preview_type="audio",
|
||||
description_field="desc",
|
||||
search_fields=["first", "last", "profile.email"],
|
||||
).as_dict()
|
||||
assert d["preview_type"] == "audio"
|
||||
assert d["search_fields"] == ["first", "last", "profile.email"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad_fields",
|
||||
[
|
||||
["{first} {last}"],
|
||||
["name", "{age}"],
|
||||
["leading{"],
|
||||
["trailing}"],
|
||||
],
|
||||
)
|
||||
def test_item_schema_rejects_template_strings_in_search_fields(bad_fields):
|
||||
with pytest.raises(ValueError, match="search_fields"):
|
||||
_schema(search_fields=bad_fields)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bad_preview_type", ["middle", "IMAGE", "", "gif"])
|
||||
def test_item_schema_rejects_unknown_preview_type(bad_preview_type):
|
||||
with pytest.raises(ValueError, match="preview_type"):
|
||||
_schema(preview_type=bad_preview_type)
|
||||
|
||||
|
||||
def test_combo_options_minimal_accepted():
|
||||
d = _combo().as_dict()
|
||||
assert d["route"] == "/proxy/foo"
|
||||
assert d["refresh_button"] is True
|
||||
assert "item_schema" in d
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"route",
|
||||
[
|
||||
"/proxy/foo",
|
||||
"/voices",
|
||||
],
|
||||
)
|
||||
def test_combo_options_accepts_valid_routes(route):
|
||||
_combo(route=route)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"route",
|
||||
[
|
||||
"",
|
||||
"api.example.com/voices",
|
||||
"voices",
|
||||
"ftp-no-scheme",
|
||||
"http://localhost:9000/voices",
|
||||
"https://api.example.com/v1/voices",
|
||||
],
|
||||
)
|
||||
def test_combo_options_rejects_non_relative_routes(route):
|
||||
with pytest.raises(ValueError, match="'route'"):
|
||||
_combo(route=route)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bad_auto_select", ["middle", "FIRST", "", "firstlast"])
|
||||
def test_combo_options_rejects_unknown_auto_select(bad_auto_select):
|
||||
with pytest.raises(ValueError, match="auto_select"):
|
||||
_combo(auto_select=bad_auto_select)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bad_refresh", [1, 127])
|
||||
def test_combo_options_refresh_in_forbidden_range_rejected(bad_refresh):
|
||||
with pytest.raises(ValueError, match="refresh"):
|
||||
_combo(refresh=bad_refresh)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ok_refresh", [0, -1, 128])
|
||||
def test_combo_options_refresh_valid_values_accepted(ok_refresh):
|
||||
_combo(refresh=ok_refresh)
|
||||
|
||||
|
||||
def test_combo_options_timeout_negative_rejected():
|
||||
with pytest.raises(ValueError, match="timeout"):
|
||||
_combo(timeout=-1)
|
||||
|
||||
|
||||
def test_combo_options_max_retries_negative_rejected():
|
||||
with pytest.raises(ValueError, match="max_retries"):
|
||||
_combo(max_retries=-1)
|
||||
|
||||
|
||||
def test_combo_options_as_dict_prunes_none_fields():
|
||||
d = _combo().as_dict()
|
||||
for pruned in ("response_key", "refresh", "timeout", "max_retries", "auto_select"):
|
||||
assert pruned not in d
|
||||
|
||||
|
||||
def test_combo_input_accepts_remote_combo_alone():
|
||||
Combo.Input("voice", remote_combo=_combo())
|
||||
|
||||
|
||||
def test_combo_input_rejects_remote_plus_remote_combo():
|
||||
with pytest.raises(ValueError, match="remote.*remote_combo"):
|
||||
Combo.Input(
|
||||
"voice",
|
||||
remote=RemoteOptions(route="/r", refresh_button=True),
|
||||
remote_combo=_combo(),
|
||||
)
|
||||
|
||||
|
||||
def test_combo_input_rejects_options_plus_remote_combo():
|
||||
with pytest.raises(ValueError, match="options.*remote_combo"):
|
||||
Combo.Input("voice", options=["a", "b"], remote_combo=_combo())
|
||||
Reference in New Issue
Block a user