Compare commits

..

4 Commits

14 changed files with 127 additions and 613 deletions

View File

@ -229,7 +229,7 @@ Python 3.14 works but some custom nodes may have issues. The free threaded varia
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
torch 2.5 is minimally supported but using a newer version is extremely recommended. Some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. If your pytorch is more than 6 months old, please update it.
### Instructions:

View File

@ -217,10 +217,7 @@ class AceStepAttention(nn.Module):
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
n_rep = self.num_heads // self.num_kv_heads
if n_rep > 1:
key_states = key_states.repeat_interleave(n_rep, dim=1)
value_states = value_states.repeat_interleave(n_rep, dim=1)
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
attn_bias = None
if self.sliding_window is not None and not self.is_cross_attention:
@ -244,7 +241,7 @@ class AceStepAttention(nn.Module):
else:
attn_bias = window_bias
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False, **gqa_kwargs)
attn_output = self.o_proj(attn_output)
return attn_output

View File

@ -425,19 +425,16 @@ class Attention(nn.Module):
if n == 1 and causal:
causal = False
if h != kv_h:
# Repeat interleave kv_heads to match q_heads
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
gqa_kwargs = {"enable_gqa": True} if h != kv_h else {}
if self.differential:
q, q_diff = q.unbind(dim=1)
k, k_diff = k.unbind(dim=1)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
out = out - out_diff
else:
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
out = self.to_out(out)

View File

@ -74,11 +74,8 @@ class BooguDoubleStreamProcessor(nn.Module):
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if attn.kv_heads < attn.heads:
key = key.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
value = value.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
gqa_kwargs = {"enable_gqa": True} if attn.kv_heads < attn.heads else {}
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options, **gqa_kwargs)
# Split back to instruction/image, apply per-stream output projections, recombine.
instruct_hidden_states = self.instruct_out(hidden_states[:, :L_instruct])

View File

@ -1,5 +1,6 @@
import math
import sys
import inspect
import torch
import torch.nn.functional as F
@ -14,16 +15,16 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled():
import xformers
import xformers.ops
SAGE_ATTENTION_IS_AVAILABLE = False
SAGE_ATTENTION_SUPPORTS_MASK = False
try:
from sageattention import sageattn
SAGE_ATTENTION_IS_AVAILABLE = True
SAGE_ATTENTION_SUPPORTS_MASK = "attn_mask" in inspect.signature(sageattn).parameters
except ImportError as e:
if model_management.sage_attention_enabled():
if e.name == "sageattention":
@ -89,6 +90,44 @@ def default(val, d):
return val
return d
def _gqa_repeat_factor(query_heads, key_heads, value_heads):
if key_heads != value_heads:
raise ValueError(f"Key/value head count mismatch for GQA: {key_heads} != {value_heads}")
if query_heads == key_heads:
return 1
if query_heads % key_heads != 0:
raise ValueError(f"Query heads must be divisible by key/value heads for GQA: {query_heads} vs {key_heads}")
return query_heads // key_heads
def _repeat_kv_for_gqa(k, v, query_heads, head_dim):
n_rep = _gqa_repeat_factor(query_heads, k.shape[head_dim], v.shape[head_dim])
if n_rep > 1:
k = k.repeat_interleave(n_rep, dim=head_dim)
v = v.repeat_interleave(n_rep, dim=head_dim)
return k, v
def _heads_from_dim(tensor, dim_head, name):
inner_dim = tensor.shape[-1]
if inner_dim % dim_head != 0:
raise ValueError(f"{name} inner dimension {inner_dim} is not divisible by head dimension {dim_head}")
return inner_dim // dim_head
def _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, enable_gqa=False, expand_kv=True):
q = q.unsqueeze(3).reshape(b, -1, heads, dim_head)
if enable_gqa:
key_heads = _heads_from_dim(k, dim_head, "Key")
value_heads = _heads_from_dim(v, dim_head, "Value")
else:
key_heads = heads
value_heads = heads
k = k.unsqueeze(3).reshape(b, -1, key_heads, dim_head)
v = v.unsqueeze(3).reshape(b, -1, value_heads, dim_head)
if enable_gqa:
_gqa_repeat_factor(heads, key_heads, value_heads)
if expand_kv:
k, v = _repeat_kv_for_gqa(k, v, heads, -2)
return q, k, v
# feedforward
class GEGLU(nn.Module):
@ -152,28 +191,19 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
q, k, v = map(
if kwargs.get("enable_gqa", False):
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
else:
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
q, k, v = map(lambda t: t.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous(), (q, k, v))
# force cast to fp32 to avoid overflowing
if attn_precision == torch.float32:
@ -231,13 +261,16 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape:
if kwargs.get("enable_gqa", False):
key, value = _repeat_kv_for_gqa(key, value, query.shape[-3], -3)
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
else:
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
query, key, value = _reshape_qkv_to_heads(query, key, value, b, heads, dim_head, kwargs.get("enable_gqa", False))
query = query.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
value = value.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype
@ -304,19 +337,15 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
if kwargs.get("enable_gqa", False):
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
else:
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
q, k, v = map(lambda t: t.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous(), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
@ -438,7 +467,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
if skip_reshape:
# b h k d -> b k h d
@ -446,13 +475,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
lambda t: t.permute(0, 2, 1, 3),
(q, k, v),
)
if kwargs.get("enable_gqa", False):
k, v = _repeat_kv_for_gqa(k, v, q.shape[-2], -2)
# actually do the reshaping
else:
dim_head //= heads
q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
if mask is not None:
# add a singleton batch dimension
@ -474,7 +502,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
mask = mask_out[..., :mask.shape[-1]]
mask = mask.expand(b, heads, -1, -1)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask, scale=kwargs.get("scale", None))
if skip_output_reshape:
out = out.permute(0, 2, 1, 3)
@ -498,10 +526,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False), expand_kv=False)
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
if mask is not None:
# add a batch dimension if there isn't already one
@ -511,9 +537,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_keys = ("scale", "enable_gqa")
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b:
@ -541,20 +565,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if kwargs.get("low_precision_attention", True) is False:
if kwargs.get("low_precision_attention", True) is False or (mask is not None and not SAGE_ATTENTION_SUPPORTS_MASK):
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
if kwargs.get("enable_gqa", False):
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
tensor_layout = "NHD"
if mask is not None:
@ -565,8 +588,12 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if mask.ndim == 3:
mask = mask.unsqueeze(1)
sage_kwargs = {"is_causal": False, "tensor_layout": tensor_layout, "sm_scale": kwargs.get("scale", None), "smooth_k": False}
if mask is not None:
sage_kwargs["attn_mask"] = mask
try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
out = sageattn(q, k, v, **sage_kwargs)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
exception_fallback = True
@ -616,7 +643,6 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
@ -642,11 +668,15 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
**kwargs
)
if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
if skip_reshape:
q_s = q
if kwargs.get("enable_gqa", False):
k_s, v_s = _repeat_kv_for_gqa(k, v, H, -3)
else:
k_s, v_s = k, v
else:
q_s, k_s, v_s = _reshape_qkv_to_heads(q, k, v, B, heads, dim_head, kwargs.get("enable_gqa", False))
q_s, k_s, v_s = map(lambda t: t.permute(0, 2, 1, 3).contiguous(), (q_s, k_s, v_s))
B, H, L, D = q_s.shape
try:
@ -662,7 +692,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
@ -681,19 +711,20 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
dropout_p: float = 0.0, causal: bool = False, softmax_scale: float = -1.0) -> torch.Tensor:
softmax_scale_arg = None if softmax_scale == -1.0 else softmax_scale
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal, softmax_scale=softmax_scale_arg)
@flash_attn_wrapper.register_fake
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False, softmax_scale=-1.0):
# Output shape is the same as q
return q.new_empty(q.shape)
except AttributeError as error:
FLASH_ATTN_ERROR = error
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
dropout_p: float = 0.0, causal: bool = False, softmax_scale: float = -1.0) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@wrap_attn
@ -703,10 +734,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False), expand_kv=False)
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
if mask is not None:
# add a batch dimension if there isn't already one
@ -725,10 +754,16 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
v.transpose(1, 2),
dropout_p=0.0,
causal=False,
softmax_scale=kwargs.get("scale", -1.0),
).transpose(1, 2)
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
sdpa_extra = {}
if kwargs.get("enable_gqa", False):
sdpa_extra["enable_gqa"] = True
if "scale" in kwargs:
sdpa_extra["scale"] = kwargs["scale"]
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -1209,5 +1244,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -141,11 +141,8 @@ class Attention(nn.Module):
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.kv_heads < self.heads:
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
gqa_kwargs = {"enable_gqa": True} if self.kv_heads < self.heads else {}
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options, **gqa_kwargs)
hidden_states = self.to_out[0](hidden_states)
return hidden_states

View File

@ -12,7 +12,7 @@ import torch.nn.functional as F
import comfy.ops
from comfy import sd1_clip
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.text_encoders.llama import RMSNorm, apply_rope
@ -110,10 +110,6 @@ def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sin
putting the sink logit in the mask at that column.
"""
if num_kv_groups > 1 and not TORCH_HAS_GQA:
k = k.repeat_interleave(num_kv_groups, dim=1)
v = v.repeat_interleave(num_kv_groups, dim=1)
B, _, S_q, D = q.shape
H_kv = k.shape[1]
S_kv = k.shape[-2]

View File

@ -550,10 +550,8 @@ class Attention(nn.Module):
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, **gqa_kwargs)
return self.o_proj(output), present_key_value
class MLP(nn.Module):

View File

@ -366,12 +366,8 @@ class GatedAttention(nn.Module):
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
# Expand KV heads for GQA
if self.num_heads != self.num_kv_heads:
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, **gqa_kwargs)
output = output * gate.sigmoid()
return self.o_proj(output), present_key_value

View File

@ -1261,155 +1261,6 @@ class DynamicSlot(ComfyTypeI):
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
class DynamicGroup(ComfyTypeI):
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
At execution time the node receives a ``list[dict]`` where each element is a row.
Example::
io.DynamicGroup.Input(
"loras",
template=[
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
],
min=0,
max=50,
)
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
"""
Type = list[dict[str, Any]]
_MaxRows = 100
class Input(DynamicInput):
def __init__(
self,
id: str,
template: list["Input"],
min: int = 0,
max: int = 50,
display_name: str = None,
optional: bool = False,
tooltip: str = None,
lazy: bool = None,
extra_dict=None,
group_name: str = "Group",
):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
assert len(template) > 0, "DynamicGroup template must have at least one field."
for t in template:
assert isinstance(t, WidgetInput), (
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
)
assert not isinstance(t, DynamicInput), (
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
"Nesting dynamic inputs inside DynamicGroup is not supported."
)
field_ids = [t.id for t in template]
assert len(field_ids) == len(set(field_ids)), (
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
)
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
# path.split(".") to produce the wrong number of segments during decoding.
assert "." not in id, (
f"DynamicGroup id must not contain '.'. Got: '{id}'"
)
for t in template:
assert "." not in t.id, (
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
)
assert min >= 0, "DynamicGroup min must be >= 0."
assert max >= 1, "DynamicGroup max must be >= 1."
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
assert min <= max, "DynamicGroup min must be <= max."
self.template = template
self.min = min
self.max = max
self.group_name = group_name
def get_all(self) -> list["Input"]:
return [self] + list(self.template)
def as_dict(self):
return super().as_dict() | prune_dict({
"template": create_input_dict_v1(self.template),
"min": self.min,
"max": self.max,
"group_name": self.group_name,
})
def validate(self):
for t in self.template:
t.validate()
@staticmethod
def _expand_schema_for_dynamic(
out_dict: dict[str, Any],
live_inputs: dict[str, Any],
value: tuple[str, dict[str, Any]],
input_type: str,
curr_prefix: list[str] | None,
):
info = value[1]
min_rows: int = info.get("min", 0)
max_rows: int = info.get("max", DynamicGroup._MaxRows)
template: dict[str, Any] = info.get("template", {})
# Collect all template field specs across required/optional sections
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
for field_required_key in ("required", "optional"):
section = template.get(field_required_key, {})
is_required_field = field_required_key == "required"
for field_id, field_value in section.items():
field_specs.append((field_id, field_value, is_required_field))
# Determine how many rows are currently present by scanning live_inputs
finalized_prefix = finalize_prefix(curr_prefix)
present_rows = 0
for live_key in live_inputs:
# Keys look like "<prefix>.<row>.<field_id>"
if live_key.startswith(finalized_prefix + "."):
remainder = live_key[len(finalized_prefix) + 1:]
parts = remainder.split(".", 1)
if len(parts) >= 1:
try:
row_idx = int(parts[0])
present_rows = max(present_rows, row_idx + 1)
except ValueError:
pass
if present_rows > max_rows:
raise ValueError(
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
)
row_count = max(min_rows, present_rows)
for row in range(row_count):
for field_id, field_value, is_required_field in field_specs:
slot_id = f"{finalized_prefix}.{row}.{field_id}"
if row < min_rows and is_required_field:
out_dict["required"][slot_id] = field_value
else:
out_dict["optional"][slot_id] = field_value
# Register into dynamic_paths so build_nested_inputs places value at the right path
out_dict["dynamic_paths"][slot_id] = slot_id
# Track the list root path so build_nested_inputs can convert the index dict to a list
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
# Handle the empty case (0 rows) emit an empty-list default for the parent.
# This must only fire when there are genuinely no rows; otherwise the parent
# path would clobber the per-row dict built from the slot ids above.
if row_count == 0:
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
@comfytype(io_type="IMAGECOMPARE")
class ImageCompare(ComfyTypeI):
Type = dict
@ -1567,8 +1418,6 @@ def setup_dynamic_input_funcs():
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
# DynamicSlot.Input
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
# DynamicGroup.Input
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
if len(DYNAMIC_INPUT_LOOKUP) == 0:
setup_dynamic_input_funcs()
@ -1580,8 +1429,6 @@ class V3Data(TypedDict):
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
dynamic_paths_default_value: dict[str, Any]
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
list_paths: set[str]
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).'
@ -1923,7 +1770,6 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
"optional": {},
"dynamic_paths": {},
"dynamic_paths_default_value": {},
"list_paths": set(),
}
d = d.copy()
# ignore hidden for parsing
@ -1939,10 +1785,6 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
list_paths = out_dict.pop("list_paths", None)
if list_paths:
v3_data["list_paths"] = list_paths
return out_dict, hidden, v3_data
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
@ -1978,12 +1820,10 @@ def add_to_dict_v1(i: Input, d: dict):
class DynamicPathsDefaultValue:
EMPTY_DICT = "empty_dict"
EMPTY_LIST = "empty_list"
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None)
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
if paths is None:
return values
values = values.copy()
@ -2006,8 +1846,6 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
default_option = default_value_dict.get(key, None)
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
value = {}
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
value = []
if create_tuple:
value = (value, key)
current[p] = value
@ -2015,34 +1853,6 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
current = current.setdefault(p, {})
values.update(result)
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
for list_path in list_paths:
parts = list_path.split(".")
# Navigate to the parent container, then convert the leaf
container = values
for part in parts[:-1]:
if not isinstance(container, dict) or part not in container:
container = None
break
container = container[part]
if container is None:
continue
leaf_key = parts[-1]
leaf = container.get(leaf_key, None)
if isinstance(leaf, dict):
try:
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
container[leaf_key] = sorted_rows
except (ValueError, TypeError):
# Keys are not all integers; leave as-is
pass
elif isinstance(leaf, list):
# Already a list (e.g. the EMPTY_LIST default was applied above)
pass
elif leaf is None:
container[leaf_key] = []
return values
@ -2607,9 +2417,7 @@ __all__ = [
# Dynamic Types
"MatchType",
"DynamicCombo",
"DynamicSlot",
"Autogrow",
"DynamicGroup",
# Other classes
"HiddenHolder",
"Hidden",

View File

@ -16,23 +16,30 @@ class ColorToRGBInt(io.ComfyNode):
],
outputs=[
io.Int.Output(display_name="rgb_int"),
io.Color.Output(display_name="hex")
io.Color.Output(display_name="hex"),
io.Float.Output(display_name="alpha"),
],
)
@classmethod
def execute(cls, color: str) -> io.NodeOutput:
# expect format #RRGGBB
if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB")
# expect format #RRGGBB or #RRGGBBAA
if len(color) not in (7, 9) or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA")
try:
int(color[1:], 16)
except ValueError:
raise ValueError("Color must be in format #RRGGBB") from None
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA") from None
alpha = 1.0
if len(color) == 9:
alpha = int(color[7:9], 16) / 255.0
color = color[:7]
r, g, b = hex_to_rgb(color)
rgb_int = r * 256 * 256 + g * 256 + b
return io.NodeOutput(rgb_int, color)
return io.NodeOutput(rgb_int, color, alpha)
class ColorExtension(ComfyExtension):

View File

@ -1,107 +0,0 @@
from __future__ import annotations
from typing_extensions import override
import comfy.sd
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, io
def _load_lora_file(lora_name: str):
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
return comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
def _lora_template() -> list[io.Input]:
return [
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras"),
tooltip="The name of the LoRA file to apply."),
io.Float.Input("strength", default=1.0, min=-100.0, max=100.0, step=0.01,
tooltip="How strongly to apply this LoRA. 0 = off, negative inverts the effect."),
]
class LoadLoraModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadLoraModel",
display_name="Load LoRA (Model)",
search_aliases=["lora", "load lora", "apply lora", "lora model", "lora stack"],
category="model/loaders",
description="Apply a stack of LoRAs to a diffusion model. Add one row per LoRA; "
"each row picks a LoRA file and its strength.",
inputs=[
io.Model.Input("model", tooltip="The diffusion model the LoRAs will be applied to."),
io.DynamicGroup.Input(
"loras",
template=_lora_template(),
min=1,
max=50,
tooltip="Each row applies one LoRA to the model.",
group_name="LoRA",
),
],
outputs=[io.Model.Output(tooltip="The modified diffusion model.")],
)
@classmethod
def execute(cls, model, loras: list[dict]) -> io.NodeOutput:
for row in loras:
lora_name = row.get("lora_name")
strength = row.get("strength", 1.0)
if not lora_name or lora_name == "none" or strength == 0:
continue
lora, metadata = _load_lora_file(lora_name)
model, _ = comfy.sd.load_lora_for_models(model, None, lora, strength, 0, lora_metadata=metadata)
return io.NodeOutput(model)
class LoadLoraTextEncoder(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadLoraTextEncoder",
display_name="Load LoRA (Text Encoder)",
search_aliases=["lora", "load lora", "apply lora", "clip lora", "lora stack"],
category="model/loaders",
description="Apply a stack of LoRAs to a CLIP text encoder. Add one row per LoRA; "
"each row picks a LoRA file and its strength.",
inputs=[
io.Clip.Input("clip", tooltip="The CLIP text encoder the LoRAs will be applied to."),
io.DynamicGroup.Input(
"loras",
template=_lora_template(),
min=1,
max=50,
tooltip="Each row applies one LoRA to the text encoder.",
group_name="LoRA",
),
],
outputs=[io.Clip.Output(tooltip="The modified CLIP text encoder.")],
)
@classmethod
def execute(cls, clip, loras: list[dict]) -> io.NodeOutput:
for row in loras:
lora_name = row.get("lora_name")
strength = row.get("strength", 1.0)
if not lora_name or lora_name == "none" or strength == 0:
continue
lora, metadata = _load_lora_file(lora_name)
_, clip = comfy.sd.load_lora_for_models(None, clip, lora, 0, strength, lora_metadata=metadata)
return io.NodeOutput(clip)
class LoraStackExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoadLoraModel,
LoadLoraTextEncoder,
]
async def comfy_entrypoint() -> LoraStackExtension:
return LoraStackExtension()

View File

@ -2502,7 +2502,6 @@ async def init_builtin_extra_nodes():
"nodes_triposplat.py",
"nodes_depth_anything_3.py",
"nodes_seed.py",
"nodes_lora_stack.py",
]
import_failed = []

View File

@ -1,204 +0,0 @@
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
import sys
import types
import pytest
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
if "torch" not in sys.modules:
_torch_stub = types.ModuleType("torch")
_torch_stub.Tensor = object # type: ignore[attr-defined]
sys.modules["torch"] = _torch_stub
from comfy_api.latest._io import ( # noqa: E402
DynamicGroup,
Float,
Int,
String,
Boolean,
get_finalized_class_inputs,
build_nested_inputs,
create_input_dict_v1,
setup_dynamic_input_funcs,
)
# Make sure dynamic input funcs are registered (may already be done at import time)
setup_dynamic_input_funcs()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
return create_input_dict_v1([group_input])
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
"""End-to-end helper: expand schema + reconstruct values.
Mirrors the production split in execution.py:
1. get_finalized_class_inputs (schema expansion, line 162)
2. build_nested_inputs (value reconstruction, line 281)
The two steps are separate in production because the engine resolves
linked node outputs between them, but in tests we supply values directly.
"""
class_inputs = _make_class_inputs(group_input)
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
return build_nested_inputs(dict(live_values), v3_data)
# ---------------------------------------------------------------------------
# Schema construction
# ---------------------------------------------------------------------------
class TestDynamicGroupInputConstruction:
def test_basic_construction(self):
inp = DynamicGroup.Input(
"loras",
template=[
Float.Input("strength", default=1.0),
String.Input("name"),
],
min=0,
max=10,
)
assert inp.id == "loras"
assert inp.min == 0
assert inp.max == 10
assert len(inp.template) == 2
def test_get_all_includes_self_and_template(self):
inp = DynamicGroup.Input(
"items",
template=[Float.Input("value")],
)
all_inputs = inp.get_all()
assert all_inputs[0] is inp
assert all_inputs[1].id == "value"
def test_as_dict_has_template_min_max(self):
inp = DynamicGroup.Input(
"items",
template=[Float.Input("val", default=0.5)],
min=1,
max=5,
)
d = inp.as_dict()
assert "template" in d
assert d["min"] == 1
assert d["max"] == 5
def test_duplicate_field_ids_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input(
"bad",
template=[Float.Input("x"), Float.Input("x")],
)
def test_empty_template_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[])
def test_min_gt_max_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
def test_max_exceeds_limit_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
def test_dynamic_input_in_template_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input(
"bad",
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
)
def test_validate_calls_through(self):
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
inp.validate() # should not raise
# ---------------------------------------------------------------------------
# 0-row case
# ---------------------------------------------------------------------------
class TestZeroRows:
def test_empty_live_inputs_produces_empty_list(self):
"""With min=0 and no live values, the result should be an empty list."""
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
assert _run(inp, {}).get("loras") == []
def test_min_zero_with_values(self):
"""min=0 but 2 rows of live data."""
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
# ---------------------------------------------------------------------------
# N-row case
# ---------------------------------------------------------------------------
class TestNRows:
def test_two_rows_two_fields(self):
"""Two rows with two fields each produce a list[dict]."""
inp = DynamicGroup.Input(
"loras",
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
min=0, max=50,
)
result = _run(inp, {
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
})
assert result["loras"] == [
{"lora_name": "model_a.safetensors", "strength": 0.9},
{"lora_name": "model_b.safetensors", "strength": 0.4},
]
def test_rows_are_sorted_by_index(self):
"""Rows must be in ascending index order even if dict iteration is unordered."""
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
assert [row["v"] for row in result["items"]] == [10, 20, 30]
def test_min_rows_schema_slots(self):
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
all_slots = {**out.get("required", {}), **out.get("optional", {})}
assert "items.0.val" in all_slots
assert "items.1.val" in all_slots
def test_min_rows_reconstructs_when_no_values(self):
"""min=2 with NO live values must still yield a 2-element list,
not collapse to [] (regression: parent-path clobber)."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
result = _run(inp, {})
assert len(result["items"]) == 2
assert all("val" in row for row in result["items"])
def test_min_rows_reconstructs_with_partial_values(self):
"""min=2 with only the first row's value present still yields 2 rows."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
result = _run(inp, {"items.0.val": 0.7})
assert len(result["items"]) == 2
assert result["items"][0]["val"] == 0.7
assert result["items"][1]["val"] is None
def test_list_paths_in_v3_data(self):
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
assert "things" in v3_data.get("list_paths", set())
def test_no_leftover_flat_keys(self):
"""Flat keys must be consumed; only the reconstructed list remains."""
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
assert "rows.0.x" not in result
assert "rows.1.x" not in result
assert isinstance(result["rows"], list)